diff --git a/Dockerfile.xpu b/Dockerfile.xpu new file mode 100644 index 0000000..bb4d2dd --- /dev/null +++ b/Dockerfile.xpu @@ -0,0 +1,68 @@ +# Base image +FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 + +ARG http_proxy +ARG https_proxy + +ENV DEBIAN_FRONTEND=noninteractive +ENV CONDA_DIR=/opt/conda + +# Install dependencies +RUN apt-get update && apt-get install -y \ + wget \ + curl \ + bash \ + git \ + vim \ + ca-certificates \ + binutils \ + cmake \ + g++ \ + && rm -rf /var/lib/apt/lists/* + +# Install Miniforge +RUN wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O /tmp/miniforge.sh && \ + bash /tmp/miniforge.sh -b -p $CONDA_DIR && \ + rm /tmp/miniforge.sh && \ + $CONDA_DIR/bin/conda clean -afy + +# Add conda to PATH +ENV PATH=$CONDA_DIR/bin:$PATH + +RUN bash -c "\ + source /opt/conda/etc/profile.d/conda.sh && \ + conda create --name ktransformers python=3.11 -y && \ + conda activate ktransformers && \ + conda env list && \ + conda install -c conda-forge libstdcxx-ng -y && \ + strings \$(find /opt/conda/envs/ktransformers/lib -name 'libstdc++.so.6') | grep GLIBCXX | grep 3.4.32 \ +" + +RUN bash -c "\ + source /opt/conda/etc/profile.d/conda.sh && \ + conda activate ktransformers && \ + pip install ipex-llm[xpu_2.6]==2.3.0b20250518 --extra-index-url https://download.pytorch.org/whl/xpu && \ + pip uninstall -y torch torchvision torchaudio && \ + pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu && \ + pip uninstall -y intel-opencl-rt dpcpp-cpp-rt && \ + pip list \ +" + +# Clone and set up ktransformers repo +RUN bash -c "\ + source $CONDA_DIR/etc/profile.d/conda.sh && \ + conda activate ktransformers && \ + git clone https://github.com/kvcache-ai/ktransformers.git && \ + cd ktransformers && \ + git submodule update --init && \ + sed -i 's/torch\.xpu\.is_available()/True/g' setup.py && \ + bash install.sh --dev xpu \ +" + +# Init conda and prepare bashrc +RUN conda init bash && \ + echo "source $CONDA_DIR/etc/profile.d/conda.sh" >> ~/.bashrc && \ + echo "conda activate ktransformers" >> ~/.bashrc + +WORKDIR /ktransformers/ +CMD ["bash"] diff --git a/README.md b/README.md index 6ad3a59..b5f62da 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,8 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

+* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)). + * **Apr 29, 2025**: Support AMX-Int8、 AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md)) https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2 @@ -116,6 +118,16 @@ https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12 Getting started with KTransformers is simple! Follow the steps below to set up and start using it. +We currently support the following vendors: + +- Metax +- Sanechips (ZhuFeng V1.0) +- Intel +- Ascend +- Kunpeng +- AMD + + ### 📥 Installation To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html). diff --git a/WeChatGroup.png b/WeChatGroup.png index c7f3c2d..8a53460 100644 Binary files a/WeChatGroup.png and b/WeChatGroup.png differ diff --git a/csrc/ktransformers_ext/CMakeLists.txt b/csrc/ktransformers_ext/CMakeLists.txt index 217de78..0ed4ef4 100644 --- a/csrc/ktransformers_ext/CMakeLists.txt +++ b/csrc/ktransformers_ext/CMakeLists.txt @@ -41,6 +41,7 @@ option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" ON) option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF) option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF) +option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF) # Architecture specific # TODO: probably these flags need to be tweaked on some architectures @@ -303,6 +304,8 @@ elseif (UNIX) message(STATUS "MUSA Toolkit found") add_compile_definitions(KTRANSFORMERS_USE_MUSA=1) endif() + elseif (KTRANSFORMERS_USE_XPU) + add_compile_definitions(KTRANSFORMERS_USE_XPU=1) else() find_package(CUDA REQUIRED) include_directories("${CUDA_INCLUDE_DIRS}") @@ -361,6 +364,7 @@ elseif(UNIX) message(STATUS "Building for HIP") elseif(KTRANSFORMERS_USE_MUSA) target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart) + elseif(KTRANSFORMERS_USE_XPU) else() target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so") endif() diff --git a/csrc/ktransformers_ext/cpu_backend/cpuinfer.h b/csrc/ktransformers_ext/cpu_backend/cpuinfer.h index 9c7e781..7b1d898 100644 --- a/csrc/ktransformers_ext/cpu_backend/cpuinfer.h +++ b/csrc/ktransformers_ext/cpu_backend/cpuinfer.h @@ -17,6 +17,7 @@ #include #include #include + #include #ifdef KTRANSFORMERS_USE_CUDA #include "vendors/cuda.h" #elif KTRANSFORMERS_USE_MUSA @@ -66,10 +67,14 @@ } void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair params) { + #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM) void (*func)(void*) = (void (*)(void*))params.first; void* args = (void*)params.second; *((CPUInfer**)args) = this; cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args); + #else + throw std::runtime_error("submit_with_cuda_stream is not supported on this platform"); + #endif } static void sync_(void* cpu_infer_ptr) { @@ -78,7 +83,11 @@ } void sync_with_cuda_stream(intptr_t user_cuda_stream) { + #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM) cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this); + #else + throw std::runtime_error("sync_with_cuda_stream is not supported on this platform"); + #endif } public: diff --git a/csrc/ktransformers_ext/cuda/test_dequant.py b/csrc/ktransformers_ext/cuda/test_dequant.py index abca745..c39d6c7 100644 ---
a/csrc/ktransformers_ext/cuda/test_dequant.py +++ b/csrc/ktransformers_ext/cuda/test_dequant.py @@ -1,7 +1,7 @@ import os import sys sys.path.insert(0,"/home/zbx/ktransformers") -from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.custom_loader import GGUFLoader import torch gguf_loader_1 = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf") diff --git a/csrc/ktransformers_ext/ext_bindings.cpp b/csrc/ktransformers_ext/ext_bindings.cpp index 2767679..f0aeaa5 100644 --- a/csrc/ktransformers_ext/ext_bindings.cpp +++ b/csrc/ktransformers_ext/ext_bindings.cpp @@ -9,7 +9,7 @@ **/ // Python bindings #include "cpu_backend/cpuinfer.h" -#ifndef KTRANSFORMERS_USE_ROCM +#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU) #include "device_launch_parameters.h" #endif #include "llamafile/flags.h" diff --git a/csrc/ktransformers_ext/operators/amx/la/amx.hpp b/csrc/ktransformers_ext/operators/amx/la/amx.hpp index 3338e09..866300d 100644 --- a/csrc/ktransformers_ext/operators/amx/la/amx.hpp +++ b/csrc/ktransformers_ext/operators/amx/la/amx.hpp @@ -843,7 +843,7 @@ inline void mat_mul(int m, int n, int k, std::shared_ptrget_submat(m, k, m_begin, k_block_begin + k_begin); __m512bh *b512 = (__m512bh *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin); - for (int m_i = 0; m_i < m; m_i++) { + for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512bh ma = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { @@ -914,7 +914,7 @@ inline void mat_mul(int m, int n, int k, std::shared_ptrget_submat(m, k, m_begin, k_block_begin + k_begin); __m512i *b512 = (__m512i *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin); - for (int m_i = 0; m_i < m; m_i++) { + for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { diff --git a/csrc/ktransformers_ext/operators/amx/moe.hpp b/csrc/ktransformers_ext/operators/amx/moe.hpp index 7e966ae..81df642 100644 --- a/csrc/ktransformers_ext/operators/amx/moe.hpp +++ b/csrc/ktransformers_ext/operators/amx/moe.hpp @@ -272,8 +272,8 @@ public: void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output, int *batch_size_tensor, Backend *backend) { - bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num); qlen = batch_size_tensor[0]; + bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num); int activated_expert = 0; for (int i = 0; i < config_.expert_num; i++) { m_local_num_[i] = 0; @@ -395,4 +395,4 @@ public: } }; -#endif \ No newline at end of file +#endif diff --git a/doc/README.md b/doc/README.md index 05df2d3..199b990 100644 --- a/doc/README.md +++ b/doc/README.md @@ -22,6 +22,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

+* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./en/xpu.md)). * **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./en/llama4.md)). * **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./en/balance-serve.md)). * **Mar 27, 2025**: Support Multi-concurrency. diff --git a/doc/en/Docker_xpu.md b/doc/en/Docker_xpu.md new file mode 100644 index 0000000..cb92d01 --- /dev/null +++ b/doc/en/Docker_xpu.md @@ -0,0 +1,94 @@ +# Intel GPU Docker Guide (Beta) + +## Prerequisites + +* Docker must be installed and running on your system. +* Create a folder to store large models and intermediate files (e.g., `/mnt/models`). +* **Before proceeding, ensure the Intel GPU driver is installed correctly on your host:** [Installation Guide](./xpu.md#1-install-intel-gpu-driver) + +--- + +## Building the Docker Image Locally + +1. Clone the repository and navigate to the project directory: + + ```bash + git clone https://github.com/kvcache-ai/ktransformers.git + cd ktransformers + ``` + +2. Build the Docker image using the XPU-specific [Dockerfile.xpu](../../Dockerfile.xpu): + + ```bash + sudo http_proxy=$HTTP_PROXY \ + https_proxy=$HTTPS_PROXY \ + docker build \ + --build-arg http_proxy=$HTTP_PROXY \ + --build-arg https_proxy=$HTTPS_PROXY \ + -t kt_xpu:0.3.1 \ + -f Dockerfile.xpu \ + . + ``` + +--- + +## Running the Container + +### 1. Start the container + +```bash +sudo docker run -td --privileged \ + --net=host \ + --device=/dev/dri \ + --shm-size="16g" \ + -v /path/to/models:/models \ + -e http_proxy=$HTTP_PROXY \ + -e https_proxy=$HTTPS_PROXY \ + --name ktransformers_xpu \ + kt_xpu:0.3.1 +``` + +**Note**: Replace `/path/to/models` with your actual model directory path (e.g., `/mnt/models`). + +--- + +### 2. Access the container + +```bash +sudo docker exec -it ktransformers_xpu /bin/bash +``` + +--- + +### 3. Set required XPU environment variables (inside the container) + +```bash +export SYCL_CACHE_PERSISTENT=1 +export ONEAPI_DEVICE_SELECTOR=level_zero:0 +export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 +``` + +--- + +### 4. Run the sample script + +```bash +python ktransformers/local_chat.py \ + --model_path deepseek-ai/DeepSeek-R1 \ + --gguf_path <path_to_gguf_files> \ + --optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \ + --cpu_infer <cpu_infer> \ + --device xpu \ + --max_new_tokens 200 +``` + +**Note**: + +* Replace `<path_to_gguf_files>` with the path to your GGUF model files. +* Replace `<cpu_infer>` with the number of CPU cores you want to use plus one. + +--- + +## Additional Information + +For more configuration options and usage details, refer to the [project README](../../README.md). To run KTransformers natively on XPU (outside of Docker), please refer to [xpu.md](./xpu.md). diff --git a/doc/en/install.md b/doc/en/install.md index 031b541..49cc0f9 100644 --- a/doc/en/install.md +++ b/doc/en/install.md @@ -45,7 +45,7 @@ Some preparation: sudo apt-get update sudo apt-get install build-essential cmake ninja-build patchelf ``` -- We recommend using [Miniconda3](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh) or [Anaconda3](https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh) to create a virtual environment with Python=3.11 to run our program.
Assuming your Anaconda installation directory is `~/anaconda3`, you should ensure that the version identifier of the GNU C++standard library used by Anaconda includes `GLIBCXX-3.4.32` +- We recommend using [Miniconda3](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh) or [Anaconda3](https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh) to create a virtual environment with Python=3.11 to run our program. Assuming your Anaconda installation directory is `~/anaconda3`, you should ensure that the version identifier of the GNU C++standard library used by Anaconda includes `GLIBCXX_3.4.32` ```sh conda create --name ktransformers python=3.11 diff --git a/doc/en/xpu.md b/doc/en/xpu.md new file mode 100644 index 0000000..78a1923 --- /dev/null +++ b/doc/en/xpu.md @@ -0,0 +1,134 @@ +# Intel GPU Support for KTransformers (Beta) + +## Introduction + +### Overview +We are excited to introduce **Intel GPU support** in KTransformers (Beta release). This implementation has been tested and developed using Intel Xeon Scalable processors and Intel Arc GPUs (such as A770 and B580). + +## Installation Guide + +### 1. Install Intel GPU Driver +Begin by installing the GPU drivers for your Intel GPU: +- [Official GPU Installation Guide for Intel GPUs](https://dgpu-docs.intel.com/driver/overview.html) + +To verify that the kernel and compute drivers are installed and functional: + +```bash +clinfo --list | grep Device + `-- Device #0: 13th Gen Intel(R) Core(TM) i9-13900K + `-- Device #0: Intel(R) Arc(TM) A770 Graphics + `-- Device #0: Intel(R) UHD Graphics 770 +``` + +> [!Important] +> Ensure that **Resizable BAR** is enabled in your system's BIOS before proceeding. This is essential for optimal GPU performance and to avoid potential issues such as `Bus error (core dumped)`. For detailed steps, please refer to the official guidance [here](https://www.intel.com/content/www/us/en/support/articles/000090831/graphics.html). + +### 2. Set Up Conda Environment +We recommend using Miniconda3/Anaconda3 for environment management: + +```bash +# Download Miniconda +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh + +# Create environment +conda create --name ktransformers python=3.11 +conda activate ktransformers + +# Install required libraries +conda install -c conda-forge libstdcxx-ng + +# Verify GLIBCXX version (should include 3.4.32) +strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX +``` + +> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3` + +### 3. Install PyTorch and IPEX-LLM +Install PyTorch with XPU backend support and [IPEX-LLM](https://github.com/intel/ipex-llm): + +```bash +pip install ipex-llm[xpu_2.6]==2.3.0b20250518 --extra-index-url https://download.pytorch.org/whl/xpu +pip uninstall torch torchvision torchaudio +pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu # install torch2.7 +pip uninstall intel-opencl-rt dpcpp-cpp-rt +``` + +### 4. 
Build KTransformers + +```bash +# Clone repository +git clone https://github.com/kvcache-ai/ktransformers.git +cd ktransformers +git submodule update --init + +# Install dependencies +bash install.sh --dev xpu +``` + +## Running DeepSeek-R1 Models + +### Configuration for 16GB VRAM GPUs +Use our optimized configuration for constrained VRAM: + +```bash +export SYCL_CACHE_PERSISTENT=1 +export ONEAPI_DEVICE_SELECTOR=level_zero:0 +export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 + +python ktransformers/local_chat.py \ + --model_path deepseek-ai/DeepSeek-R1 \ + --gguf_path <path_to_gguf_files> \ + --optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \ + --cpu_infer <cpu_infer> \ + --device xpu \ + --max_new_tokens 200 +``` + +## Known Limitations +- The serving function is not yet supported on the Intel GPU platform + +## Troubleshooting +1. Best Known Config (BKC) for best performance + +To obtain the best performance on the Intel GPU platform, we recommend locking the GPU frequency and setting the CPU to performance mode with the settings below. +```bash +echo "performance" | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor +echo 0 | sudo tee /sys/devices/system/cpu/cpu*/power/energy_perf_bias +# 2400 MHz is the max frequency for Arc A770 +sudo xpu-smi config -d 0 -t 0 --frequencyrange 2400,2400 +# 2850 MHz is the max frequency for Arc B580 +# sudo xpu-smi config -d 0 -t 0 --frequencyrange 2850,2850 +``` + +2. Runtime error like `xpu/sycl/TensorCompareKernels.cpp:163: xxx. Aborted (core dumped)` + +This error is usually related to the GPU driver. If you encounter it, update `intel-level-zero-gpu` to `1.3.29735.27-914~22.04` (a version we have verified) with the commands below. +```bash +wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \ +sudo gpg --dearmor --output /usr/share/keyrings/intel-graphics.gpg +echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu jammy client" | \ +sudo tee /etc/apt/sources.list.d/intel-gpu-jammy.list +sudo apt update +# or sudo apt update --allow-insecure-repositories +sudo apt install intel-level-zero-gpu=1.3.29735.27-914~22.04 +``` + +3. `ImportError: cannot import name 'intel' from 'triton._C.libtriton'` + +Installing Triton causes pytorch-triton-xpu to stop working. You can resolve the issue with the following commands: +```bash +pip uninstall triton pytorch-triton-xpu +# Reinstall the correct version of pytorch-triton-xpu +pip install pytorch-triton-xpu==3.3.0 --index-url https://download.pytorch.org/whl/xpu +``` + +4. `ValueError: Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.` + +Ensure you have permission to access `/dev/dri/renderD*`. This typically requires your user to be in the `render` group: +```bash +sudo gpasswd -a ${USER} render +newgrp render +``` + +## Additional Information +To run KTransformers on XPU with Docker, please refer to [Docker_xpu.md](./Docker_xpu.md).
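The native guide above and the Docker guide share the same runtime environment variables and `local_chat.py` flags, so they can be collected into a single launch helper. The sketch below is illustrative only and is not part of the repository: it assumes the `ktransformers` conda environment from step 2 is already active, and the GGUF path (`/mnt/models/DeepSeek-R1-GGUF`) and `--cpu_infer 32` values are placeholders to adapt to your machine.

```bash
#!/usr/bin/env bash
# Illustrative XPU launch helper (not shipped with KTransformers).
# Assumes the "ktransformers" conda env is active; the GGUF path and
# --cpu_infer value below are placeholders.
set -e

# Recommended XPU runtime settings from this guide
export SYCL_CACHE_PERSISTENT=1
export ONEAPI_DEVICE_SELECTOR=level_zero:0
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1

# Sanity check: PyTorch must see the XPU backend, otherwise KTransformers fails
# with "Unsupported backend: ... XPU is not available" (see Troubleshooting item 4).
python -c "import torch; print('torch', torch.__version__); assert torch.xpu.is_available(), 'XPU backend not available'"

python ktransformers/local_chat.py \
    --model_path deepseek-ai/DeepSeek-R1 \
    --gguf_path /mnt/models/DeepSeek-R1-GGUF \
    --optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \
    --cpu_infer 32 \
    --device xpu \
    --max_new_tokens 200
```

If the assertion fails, revisit the driver installation in step 1 and the `render` group permissions in Troubleshooting item 4 before launching.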
diff --git a/install-with-cache.sh b/install-with-cache.sh new file mode 100755 index 0000000..cef4341 --- /dev/null +++ b/install-with-cache.sh @@ -0,0 +1,26 @@ +#!/bin/bash +set -e + +# clear build dirs +# rm -rf build +# rm -rf *.egg-info +# rm -rf csrc/build +# rm -rf csrc/ktransformers_ext/build +# rm -rf csrc/ktransformers_ext/cuda/build +# rm -rf csrc/ktransformers_ext/cuda/dist +# rm -rf csrc/ktransformers_ext/cuda/*.egg-info +rm -rf ~/.ktransformers +echo "Installing python dependencies from requirements.txt" +pip install -r requirements-local_chat.txt +pip install -r ktransformers/server/requirements.txt +echo "Installing ktransformers" +KTRANSFORMERS_FORCE_BUILD=TRUE USE_BALANCE_SERVE=1 pip install -v . --no-build-isolation +pip install third_party/custom_flashinfer/ -v + +# SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") +# echo "Copying thirdparty libs to $SITE_PACKAGES" +# cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/ +# patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython* + + +echo "Installation completed successfully" diff --git a/install.sh b/install.sh index 06866c6..573826e 100644 --- a/install.sh +++ b/install.sh @@ -1,6 +1,20 @@ #!/bin/bash set -e +# default backend +DEV="cuda" + +# parse --dev argument +while [[ "$#" -gt 0 ]]; do + case $1 in + --dev) DEV="$2"; shift ;; + *) echo "Unknown parameter passed: $1"; exit 1 ;; + esac + shift +done +export DEV_BACKEND="$DEV" +echo "Selected backend: $DEV_BACKEND" + # clear build dirs rm -rf build rm -rf *.egg-info @@ -13,14 +27,17 @@ rm -rf ~/.ktransformers echo "Installing python dependencies from requirements.txt" pip install -r requirements-local_chat.txt pip install -r ktransformers/server/requirements.txt + echo "Installing ktransformers" KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . 
--no-build-isolation -pip install third_party/custom_flashinfer/ +if [[ "$DEV_BACKEND" == "cuda" ]]; then + echo "Installing custom_flashinfer for CUDA backend" + pip install third_party/custom_flashinfer/ +fi # SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") # echo "Copying thirdparty libs to $SITE_PACKAGES" # cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/ # patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython* - echo "Installation completed successfully" \ No newline at end of file diff --git a/ktransformers/__init__.py b/ktransformers/__init__.py index fa10c92..1fa9717 100644 --- a/ktransformers/__init__.py +++ b/ktransformers/__init__.py @@ -8,4 +8,4 @@ Version : 1.0.0 LastEditors : chenxl LastEditTime : 2025-02-15 03:53:02 ''' -__version__ = "0.3" +__version__ = "0.3.1" diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml index c4f6186..ed1713b 100644 --- a/ktransformers/configs/config.yaml +++ b/ktransformers/configs/config.yaml @@ -29,7 +29,7 @@ model: gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF device: cuda:0 - cache_lens: 8192 + cache_lens: 16384 max_new_tokens: 500 web: mount: False diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 928de48..75e12fb 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM -from ktransformers.util.utils import prefill_and_generate, get_compute_capability +from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model from ktransformers.server.config.config import Config from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor @@ -63,18 +63,24 @@ def local_chat( prompt_file : str | None = None, mode: str = "normal", force_think: bool = False, - chunk_size: int = 8192 + chunk_size: int = 8192, + device: str = "cuda" ): torch.set_grad_enabled(False) Config().cpu_infer = cpu_infer + Config().chunk_size = chunk_size + if torch.xpu.is_available(): + use_cuda_graph = False tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) if mode == 'long_context': assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode" torch.set_default_dtype(torch.float16) + elif xpu_fp16_model(config): + torch.set_default_dtype(torch.float16) else: torch.set_default_dtype(config.torch_dtype) @@ -89,11 +95,16 @@ def local_chat( config._attn_implementation = "eager" if "Mixtral" in config.architectures[0]: config._attn_implementation = "flash_attention_2" - + if torch.xpu.is_available(): + config._attn_implementation = "eager" model = custom_models[config.architectures[0]](config) else: + if torch.xpu.is_available(): + attn_implementation = "eager" + else: + attn_implementation = "flash_attention_2" model = AutoModelForCausalLM.from_config( - config, trust_remote_code=True, attn_implementation="flash_attention_2" + config, trust_remote_code=True, attn_implementation=attn_implementation ) if optimize_config_path is None: @@ -109,7 +120,7 @@ def 
local_chat( gguf_path = input( "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" ) - optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) + optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, default_device=device) try: model.generation_config = GenerationConfig.from_pretrained(model_path) @@ -172,12 +183,12 @@ def local_chat( if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA: generated = prefill_and_generate( - model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, + model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim ) else: generated = prefill_and_generate( - model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, + model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, ) diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index e4a271e..350af73 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -66,7 +66,7 @@ class StaticCache(transformers.StaticCache): self.page_table_list = [] for idx in range(config.num_hidden_layers): if isinstance(device, dict): - target_device = device[f"blk.{idx}.self_attn"]["generate_device"] + target_device = device[f"model.layers.{idx}.self_attn"]["generate_device"] else: target_device = device @@ -91,7 +91,7 @@ class StaticCache(transformers.StaticCache): # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. 
if isinstance(device, dict): - target_device = device[f"blk.{idx}.self_attn"]["generate_device"] + target_device = device[f"model.layers.{idx}.self_attn"]["generate_device"] else: target_device = device @@ -213,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module): self.v_caches = [] - def load(self, inference_context: "sched_ext.InferenceContext"): + def load(self, inference_context: "sched_ext.InferenceContext"): for i in range(self.config.num_hidden_layers): self.k_caches.append( @@ -293,7 +293,7 @@ class KGQACache(nn.Module): self.v_caches = [] - def load(self, inference_context: sched_ext.InferenceContext): + def load(self, inference_context: "sched_ext.InferenceContext"): print(self.config.num_hidden_layers) for i in range(self.config.num_hidden_layers): self.k_caches.append( diff --git a/ktransformers/models/custom_modeling_qwen2_moe.py b/ktransformers/models/custom_modeling_qwen2_moe.py index 5740c14..1c84cbf 100644 --- a/ktransformers/models/custom_modeling_qwen2_moe.py +++ b/ktransformers/models/custom_modeling_qwen2_moe.py @@ -39,7 +39,7 @@ class KQwen2MoeForCausalLM(Qwen2MoePreTrainedModel): self.cache = cache self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.attn = [None] * 10 + self.attn = [None] * 100 def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0): self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device) diff --git a/ktransformers/models/custom_modeling_qwen3_moe.py b/ktransformers/models/custom_modeling_qwen3_moe.py index 1cb8c46..32b9797 100644 --- a/ktransformers/models/custom_modeling_qwen3_moe.py +++ b/ktransformers/models/custom_modeling_qwen3_moe.py @@ -39,7 +39,7 @@ class KQwen3MoeForCausalLM(Qwen3MoePreTrainedModel): self.cache = cache self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.attn = [None] * 10 + self.attn = [None] * 100 def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0): self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device) diff --git a/ktransformers/models/modeling_deepseek.py b/ktransformers/models/modeling_deepseek.py index e14a521..f6845ec 100644 --- a/ktransformers/models/modeling_deepseek.py +++ b/ktransformers/models/modeling_deepseek.py @@ -107,6 +107,7 @@ class DeepseekV2RMSNorm(nn.Module): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + self.hidden_size = hidden_size def forward(self, hidden_states): input_dtype = hidden_states.dtype diff --git a/ktransformers/models/modeling_deepseek_v3.py b/ktransformers/models/modeling_deepseek_v3.py index f296d9f..3a59d77 100644 --- a/ktransformers/models/modeling_deepseek_v3.py +++ b/ktransformers/models/modeling_deepseek_v3.py @@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, @@ -1598,7 +1599,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): return causal_mask -class 
DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): +class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index 75d1a6e..85d6556 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -23,7 +23,7 @@ from ktransformers.models.modeling_deepseek import ( yarn_find_correction_range ) from ktransformers.operators.base_operator import BaseInjectedModule -from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.custom_loader import GGUFLoader from ktransformers.util.utils import InferenceState from transformers.configuration_utils import PretrainedConfig import torch diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index 0f5f9ae..9dfdbdc 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -13,9 +13,10 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.modeling_llama import LlamaRotaryEmbedding from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb +from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule -from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.custom_loader import GGUFLoader from ktransformers.util.utils import get_compute_capability import logging from transformers.configuration_utils import PretrainedConfig @@ -587,6 +588,100 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): return attn_output, None, past_key_value + def forward_xpu( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + position_embeddings = kwargs.get("position_embeddings", None) + if position_embeddings is not None: + cos, sin = position_embeddings + key_states = torch.cat( + [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])], + dim=-1 + ) + from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced + rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim :], + key_states[:, :, :, self.qk_nope_head_dim:], + cos, sin, True) + else: + q_nope, q_pe = torch.split( + query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + cos, sin = self.rotary_emb(q_pe, position_ids) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states.half(), value_states.half(), self.layer_idx, cache_kwargs + ) + + attn_weights = None + from ipex_llm.transformers.models.common import scaled_dot_product_attention + attn_output = scaled_dot_product_attention( + query_states.half(), key_states, value_states, + attention_mask.half(), q_len == kv_seq_len, self.softmax_scale + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output).to(hidden_states.dtype) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + def forward( self, hidden_states: torch.Tensor, @@ -598,7 +693,21 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA: + if torch.xpu.is_available(): + return self.forward_xpu( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs, + ) + elif (os.name == 'nt' + or get_compute_capability() < 8 + or hidden_states.device.type == 'cpu' + or device_manager.gpu_vendor != GPUVendor.NVIDIA): return self.forward_windows( hidden_states, attention_mask, @@ -762,3 +871,75 @@ class KLlamaAttention(BaseInjectedModule): attn_weights = None return attn_output, attn_weights, past_key_value + + +class KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "xpu", + generate_device: str = "xpu", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, 
generate chunck_size automatically. + assert prefill_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device" + assert generate_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device" + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.Tensor], + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + bsz, q_len, _ = hidden_states.size() + input_dtype = hidden_states.dtype + hidden_shape = (*input_shape, -1, self.head_dim) + + if not hasattr(self, 'qkv_proj'): + from ipex_llm.transformers.models.common import merge_quantized_qkv + merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module) + + qkv = self.qkv_proj(hidden_states) + qkv = qkv.view(bsz, q_len, -1, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.num_key_value_heads], dim=1) + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + if position_embeddings is None: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + cos, sin = position_embeddings + + from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced + rotary_half_with_cache_inplaced(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states.half(), value_states.half(), + self.layer_idx, cache_kwargs) + + attn_weights = None + from ipex_llm.transformers.models.common import scaled_dot_product_attention + attn_output = scaled_dot_product_attention( + query_states.half(), key_states, value_states, + attention_mask.half(), q_len == key_states.size(2), self.scaling + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output).to(input_dtype) + return attn_output, attn_weights diff --git a/ktransformers/operators/balance_serve_attention.py b/ktransformers/operators/balance_serve_attention.py index a785413..51695f3 100644 --- a/ktransformers/operators/balance_serve_attention.py +++ b/ktransformers/operators/balance_serve_attention.py @@ -11,7 +11,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule -from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.custom_loader import GGUFLoader import logging from transformers.configuration_utils import PretrainedConfig from flashinfer import BatchMLAPagedAttentionWrapper @@ -288,3 +288,170 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention): attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors) return attn_output + + +class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention): + def __init__(self, + key: 
str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. + + + def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: + if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): + kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) + q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank) + out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank) + self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, + bias=False, dtype=q_absorb.dtype, device=q_absorb.device) + self.q_absorb.weight.data = q_absorb + self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, + bias=False, dtype=out_absorb.dtype, device=out_absorb.device) + self.out_absorb.weight.data = out_absorb + #del self.orig_module.kv_b_proj + q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) + out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank) + return q_absorb, out_absorb + + + + def forward(self, + hidden_states: torch.Tensor, + kv_cache: KDeepSeekV3Cache, + position_ids: torch.Tensor, + wrapper: None, + num_tokens_tensors: torch.Tensor, + page_idx: torch.Tensor, + page_offset: torch.Tensor, + attention_masks: Optional[list[torch.Tensor]] = None, + q_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_indptr: Optional[torch.Tensor] = None, + bsz_tensors: Optional[torch.Tensor] = None, + last_page_len: Optional[torch.Tensor] = None, + ): + # range bsz_tensors + final_attention_output = torch.tensor([], device=hidden_states.device) + for i in range(bsz_tensors[0]): + batch_num_tokens_tensors = q_indptr[i+1] - q_indptr[i] + batch_last_page_len = last_page_len[i] + # kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe + batch_page_idx = page_idx[q_indptr[i]:q_indptr[i+1]] + batch_page_offset = page_offset[q_indptr[i]:q_indptr[i+1]] + # kv_page_nums is the number of pages for the current batch + kv_page_nums = kv_indptr[i+1] - kv_indptr[i] + # kv_total_len is the total length of the kv cache for the current batch (kv_len for algorithm) + kv_total_len = kv_page_nums * kv_cache.page_size + if batch_last_page_len is not None: + kv_total_len = kv_total_len - (kv_cache.page_size - batch_last_page_len) + # print(f"kv_total_len's shape {kv_total_len.shape}") + # kv_index is the index of the kv cache pages for the current batch + kv_index = kv_indices[kv_indptr[i]:kv_indptr[i+1]] + # we can index [kv_index, page_offset_indices] to get the kv cache for the current batch + # from q_indptr[i] to q_indptr[i+1] is the range of the current batch + batch_hidden_states = hidden_states[q_indptr[i]:q_indptr[i+1]] + batch_position_ids = position_ids[q_indptr[i]:q_indptr[i+1]] + q_len, _ = batch_hidden_states.size() + # print("q_len -> ", q_len) + + if self.q_lora_rank is None: + q = self.q_proj(batch_hidden_states, batch_num_tokens_tensors) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(batch_hidden_states, batch_num_tokens_tensors), batch_num_tokens_tensors), 
batch_num_tokens_tensors) + # for v3, bsz, q_len, num_heads(128), qk_head_dim(192=128(nope)+64(rope)) + q = q.view(q_len, self.num_heads, self.q_head_dim) + # q_nope is [q_len, num_heads(128), qk_nope_head_dim(128)] + # q_pe is [q_len, num_heads(128), qk_rope_head_dim(64)] + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + # compressed_kv is [q_len, kv_lora_rank(512) + rope(64)] + compressed_kv = self.kv_a_proj_with_mqa(batch_hidden_states, batch_num_tokens_tensors) + # compressed_kv is [q_len, kv_lora_rank(512)], k_pe is [q_len, rope(64)] + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + compressed_kv = compressed_kv.contiguous() + compressed_kv = self.kv_a_layernorm(compressed_kv, batch_num_tokens_tensors) + # k_pe is [q_len, 1, qk_rope_head_dim(64)] + k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim) + # compressed_kv is [q_len, 1, kv_lora_rank(512)] + compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank) + + cos, sin = self.rotary_emb(q_pe, batch_position_ids.unsqueeze(0)) + # print(f"q_pe shape{q_pe.shape}, k_pe shape {k_pe.shape}") + q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2) + q_pe = q_pe.squeeze(0) + # q_pe is [num_heads(128), q_len, qk_rope_head_dim(64)] + q_pe.transpose_(0, 1) + if kv_cache is not None: + cache_kwargs = {"sin": sin, "cos": cos, "page_idx": batch_page_idx, "page_offset": batch_page_offset} # Specific to RoPE models + compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, batch_page_idx, batch_page_offset, cache_kwargs) + compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank) + k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim) + # q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)] + # out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim + q_absorb, out_absorb = self.get_absorbed() + # q_nope is [num_heads(128), q_len, qk_nope_head_dim(128)] + q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below + # q_nope is [num_heads(128), q_len, kv_lora_rank(512)] + q_nope = torch.matmul(q_nope, q_absorb) # batched MM + + # # q_nope is [q_len, num_heads(128), kv_lora_rank(512)] + # q_nope = q_nope.transpose(0, 1) + + # we need to index out the compressed_kv and k_pe for the current batch + batch_compressed_kv = None + batch_k_pe = None + for page_index in kv_index: + if kv_total_len > kv_cache.page_size: + tmp_compressed_kv = compressed_kv[page_index, 0:kv_cache.page_size, :] + tmp_k_pe = k_pe[page_index, 0:kv_cache.page_size, :] + if batch_compressed_kv is None or batch_k_pe is None: + batch_compressed_kv = tmp_compressed_kv + batch_k_pe = tmp_k_pe + else: + batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0) + batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0) + kv_total_len -= kv_cache.page_size + else: + tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :] + tmp_k_pe = k_pe[page_index, 0:kv_total_len, :] + if batch_compressed_kv is None or batch_k_pe is None: + batch_compressed_kv = tmp_compressed_kv + batch_k_pe = tmp_k_pe + else: + batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0) + batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0) + break + # batch_compressed_kv is 
[kv_total_len(k_len), kv_lora_rank(512)] + # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)] + attention_weights = (torch.matmul(q_pe,batch_k_pe.mT) + torch.matmul(q_nope, batch_compressed_kv.mT)) * self.softmax_scale + # attention_weights is [num_heads(128), q_len, k_len] + + # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(q_len,-1,-1).transpose(0,1) + + # attention_masks[i] is [q_len, k_len] + + attention_weights = (attention_weights + attention_masks[i]) + # attention_weights shape is [num_heads(128), q_len, k_len] + attention_weights = nn.functional.softmax(attention_weights,dim=-1,dtype=torch.float32).to(q_pe.dtype) + attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),q_len, lora_rank(512)] + # out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] + out_absorb = out_absorb.transpose(1,2) + # q for q_len, n for num_heads, h for v_head_dim, v for kv_lora_rank + attn_output = torch.matmul(attn_output, out_absorb) # [num_heads(128), q_len, v_head_dim(128)] + attn_output = attn_output.transpose(0, 1) # [q_len, num_heads(128), v_head_dim(128)] + attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output, batch_num_tokens_tensors) + final_attention_output = torch.cat((final_attention_output, attn_output), dim=0) + return final_attention_output \ No newline at end of file diff --git a/ktransformers/operators/base_operator.py b/ktransformers/operators/base_operator.py index 0fa2efd..5e49709 100644 --- a/ktransformers/operators/base_operator.py +++ b/ktransformers/operators/base_operator.py @@ -6,7 +6,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' from typing import Any from torch import nn, Tensor -from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.custom_loader import GGUFLoader from transformers.configuration_utils import PretrainedConfig import ktransformers.util.utils as utils class BaseInjectedModule(nn.Module): diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 34f0af0..7a40168 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -26,7 +26,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext import cpuinfer_ext from cpuinfer_ext.moe import MOEConfig, MOE import ctypes -from ktransformers.util.custom_gguf import GGMLQuantizationType, GGUFLoader +from ktransformers.util.custom_gguf import GGMLQuantizationType +from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader, ModelLoader from ktransformers.util.utils import InferenceState from ktransformers.server.config.config import Config from transformers.activations import ACT2FN @@ -39,8 +40,21 @@ from ktransformers.operators.cpuinfer import CPUInfer def deduplicate_and_sort(lst): return sorted(set(lst)) +def generate_cuda_graphs(chunk_size: int) -> list: + assert chunk_size <= 1024 or chunk_size % 1024 == 0, "chunk_size must <= 1024 or a multiple of 1024" + base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size] + + if chunk_size <= 1024: + return deduplicate_and_sort(base_list) + + multiples = [i for i in range(1024, chunk_size + 1, 1024)] + + return deduplicate_and_sort(base_list + multiples) #cuda_graphs = [Config().chunk_size] -cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size]) +if torch.cuda.is_available(): + cuda_graphs = 
generate_cuda_graphs(Config().chunk_size) +else: + cuda_graphs = 1 # class Base(BaseInjectedModule, ABC): class KExpertsBase(ABC): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs): @@ -77,7 +91,7 @@ class KExpertsBase(ABC): down_type = None for key in keys: - if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: + if self.gguf_loader.has_tensor(key + ".ffn_gate_exps.weight"): targets = [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight" ] tensors = self.load_multi(key, targets, device=device) gate = tensors[".ffn_gate_exps.weight"] @@ -86,7 +100,7 @@ class KExpertsBase(ABC): gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] - elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info: + elif self.gguf_loader.has_tensor(key + ".ffn_down.0.weight"): # for supporting Mixtral-8x7B-Instuct gate = [] up = [] @@ -166,6 +180,11 @@ class KExpertsCPU(KExpertsBase): n_routed_experts = self.n_routed_experts self.cpu_infer = KExpertsCPU.CPU_INFER # n_routed_experts = len(self.orig_module) + model_dtype = torch.get_default_dtype() + if torch.xpu.is_available() and model_dtype == torch.float16: + hidden_type = 1 # fp16 + else: + hidden_type = 30 # bf16 if self.backend == "llamafile": moe_config = MOEConfig( n_routed_experts, @@ -181,7 +200,7 @@ class KExpertsCPU(KExpertsBase): self.gate_type, self.up_type, self.down_type, - 30, # TODO: get from model.dtype + hidden_type, # TODO: get from model.dtype ) self.moe = MOE(moe_config) elif self.backend == "AMXBF16": @@ -194,7 +213,7 @@ class KExpertsCPU(KExpertsBase): self.config.num_experts_per_tok, self.config.hidden_size, self.config.moe_intermediate_size, - 25600, + max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size, gate_ptr, up_ptr, down_ptr, @@ -212,7 +231,7 @@ class KExpertsCPU(KExpertsBase): self.config.num_experts_per_tok, self.config.hidden_size, self.config.moe_intermediate_size, - 25600, + max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size, gate_ptr, up_ptr, down_ptr, @@ -241,8 +260,12 @@ class KExpertsCPU(KExpertsBase): KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True) KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) - KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) - KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) + if torch.xpu.is_available(): + KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=model_dtype) + KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True) + else: + KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) + KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, 
cuda_graph_idx=0): if bsz_tensor is None: @@ -274,9 +297,9 @@ class KExpertsCPU(KExpertsBase): def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0): # generate, capture and run cuda graph # print(expert_ids) - if bsz_tensor is None: + if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1): bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32) - if torch.cuda.is_current_stream_capturing(): + if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): if cuda_graph_idx != -1: KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True) KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True) @@ -296,6 +319,15 @@ class KExpertsCPU(KExpertsBase): self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) return KExpertsCPU.output_gpu_map[self.out_device] + elif input_tensor.size(0)==1 and torch.xpu.is_available(): + KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True) + KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True) + KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True) + # KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True) + self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr())) + self.cpu_infer.sync() + KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) + return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1) else: input_tensor = input_tensor.contiguous().cpu() expert_ids = expert_ids.contiguous().cpu() @@ -325,14 +357,19 @@ class KExpertsCPU(KExpertsBase): down_type = None for key in keys: - if self.gguf_loader.safetensor_loader is not None: - # using a temp ugly way to temprary load the tensor - gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy() - up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy() - down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy() - gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item() - up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item() - down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item() + if isinstance(self.gguf_loader, SafeTensorLoader): + res = self.gguf_loader.load_experts(key) + return {key: res} + elif self.gguf_loader.has_tensor(key + ".ffn_gate_exps.weight"): + gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight") + up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight") + down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight") + # gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] + # up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] + # down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] + gate_type = self.gguf_loader.get_ggml_type(key + ".ffn_gate_exps.weight") + up_type = self.gguf_loader.get_ggml_type(key + ".ffn_up_exps.weight") + down_type 
= self.gguf_loader.get_ggml_type(key + ".ffn_down_exps.weight") elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight") @@ -356,9 +393,9 @@ class KExpertsCPU(KExpertsBase): gate = np.stack(gate) up = np.stack(up) down = np.stack(down) - gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate.0.weight"]["ggml_type"] - up_type = self.gguf_loader.tensor_info[key + ".ffn_up.0.weight"]["ggml_type"] - down_type = self.gguf_loader.tensor_info[key + ".ffn_down.0.weight"]["ggml_type"] + gate_type = self.gguf_loader.get_ggml_type(key + ".ffn_gate.0.weight") + up_type = self.gguf_loader.get_ggml_type(key + ".ffn_up.0.weight") + down_type = self.gguf_loader.get_ggml_type(key + ".ffn_down.0.weight") else: raise ValueError(f"Experts {key} not found in gguf_loader") res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} @@ -445,7 +482,7 @@ class KExpertsMarlin(KExpertsBase): down = None for key in keys: - if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: + if self.gguf_loader.has_tensor(key + ".ffn_gate_exps.weight"): gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight") up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight") down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight") @@ -806,7 +843,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE): topk_idx, topk_weight, aux_loss = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) @@ -906,7 +943,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # only for generate phase - if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) @@ -1106,7 +1143,7 @@ class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE): # only for generate phase - if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity, bsz_tensor).squeeze(0) @@ -1288,7 +1325,7 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock): routing_weights = 
routing_weights.to(hidden_states.dtype) # only for generate phase - if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx) y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ @@ -1384,24 +1421,33 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock): return final_out class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock): - def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0): + def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0): orig_shape = hidden_states.shape sequence_length = orig_shape[1] hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - router_logits = self.gate(hidden_states, bsz_tensor) + if bsz_tensor is None: + router_logits = self.gate(hidden_states) + else: + router_logits = self.gate(hidden_states, bsz_tensor) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + if router_logits.device.type == "xpu": + from ipex_llm.transformers.models.common import moe_softmax_topk + selected_experts, routing_weights = moe_softmax_topk( + router_logits.half(), self.top_k, self.norm_topk_prob + ) + else: + routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) # only for generate phase - if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx) # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ diff --git a/ktransformers/operators/flashinfer_batch_prefill_wrapper.py b/ktransformers/operators/flashinfer_batch_prefill_wrapper.py index e934654..287affb 100644 --- a/ktransformers/operators/flashinfer_batch_prefill_wrapper.py +++ b/ktransformers/operators/flashinfer_batch_prefill_wrapper.py @@ -40,7 +40,7 @@ class flashInferAttn(): self.kv_layout = kv_layout self.use_cuda_graph = use_cuda_graph if flashInferAttn.float_workspace_buffer is None: - flashInferAttn.float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device=device) + flashInferAttn.float_workspace_buffer = torch.empty(max_batch_token * 1024 * 1024, dtype=torch.uint8, device=device) self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) self.paged_kv_indptr_buf = 
torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device) diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index cf5799e..f5f96c1 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -6,7 +6,7 @@ import os from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.linear import KTransformersLinear -from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.custom_loader import GGUFLoader, ModelLoader, SafeTensorLoader from transformers.configuration_utils import PretrainedConfig from abc import ABC, abstractmethod @@ -55,24 +55,20 @@ class KMoEGateBase(ABC): down_type = None for key in keys: - key = ".".join(key.split(".")[:-1]) - if self.gguf_loader.safetensor_loader is not None: - targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] - weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight") - e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias") - weight_type = weight.dtype - e_score_correction_bias_type = e_score_correction_bias.dtype - res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type} - elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info: - targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] + # key = ".".join(key.split(".")[:-1]) + if isinstance(self.gguf_loader, SafeTensorLoader): + res = self.gguf_loader.load_gate(key, device=device) + elif self.gguf_loader.has_tensor(key+".weight"): + # targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] + targets = [".weight", ".e_score_correction_bias"] tensors = self.load_multi(key, targets, device=device) - weight = tensors[".ffn_gate_inp.weight"] - e_score_correction_bias = tensors[".exp_probs_b.bias"] - weight_type = self.gguf_loader.tensor_info[key + ".ffn_gate_inp.weight"]["ggml_type"] - e_score_correction_bias_type = self.gguf_loader.tensor_info[key + ".exp_probs_b.bias"]["ggml_type"] + weight = tensors[".weight"] + e_score_correction_bias = tensors[".e_score_correction_bias"] + # weight_type = self.gguf_loader.tensor_info[key + ".weight"]["ggml_type"] + res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias} else: raise ValueError(f"Experts {key} not found in gguf_loader") - res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type} + return res def load_multi(self, key: str, keys: list[str], device: str = "cpu"): @@ -106,8 +102,6 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase): if w is None: w = self.load_weights(device=device) if isinstance(w, dict): - self.weight_type = w["weight_type"] - self.e_score_correction_bias_type = w["e_score_correction_bias_type"] self.orig_module.weight = nn.Parameter(w["weight"]) self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"]) else: @@ -175,8 +169,6 @@ class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase): if w is None: w = self.load_weights(device=device) if isinstance(w, dict): - self.weight_type = w["weight_type"] - self.e_score_correction_bias_type = w["e_score_correction_bias_type"] self.orig_module.weight = nn.Parameter(w["weight"]) 
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"]) else: @@ -190,4 +182,34 @@ class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase): if self.weight is not None: self.weight = None if self.e_score_correction_bias is not None: - self.e_score_correction_bias = None \ No newline at end of file + self.e_score_correction_bias = None + + +class KMoEGateIPEXLLM(KMoEGate): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + generate_device: str = "xpu", + prefill_device: str = "xpu", + **kwargs, + ): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) + KMoEGate.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + self.generate_device = generate_device + self.prefill_device = prefill_device + + def forward(self, hidden_states) -> torch.Tensor: + x = hidden_states.view(-1, hidden_states.size(-1)) + logits = torch.nn.functional.linear( + x.type(torch.float32), self.orig_module.weight.type(torch.float32), None + ) + scores = logits.sigmoid() + + from ipex_llm.transformers.models.common import moe_group_topk + topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias, + self.n_group, self.topk_group, self.top_k, + self.norm_topk_prob, self.routed_scaling_factor) + return topk_idx, topk_weight.to(x.dtype) \ No newline at end of file diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py index 62c5cba..796592c 100644 --- a/ktransformers/operators/layernorm.py +++ b/ktransformers/operators/layernorm.py @@ -29,11 +29,12 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm from ktransformers.operators.base_operator import BaseInjectedModule -from ktransformers.util.custom_gguf import GGUFLoader -from flashinfer.norm import ( - fused_add_rmsnorm, - rmsnorm, -) +from ktransformers.util.custom_loader import GGUFLoader +if not torch.xpu.is_available(): + from flashinfer.norm import ( + fused_add_rmsnorm, + rmsnorm, + ) logger = logging.getLogger(__name__) @@ -163,3 +164,62 @@ class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule): variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) + +class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.hidden_size, + orig_module.variance_epsilon) + + def forward( + self, + x, + batch_size_tensor: torch.Tensor = None, + residual: Optional[torch.Tensor] = None, + )-> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + x = x + residual + residual = x + # range batch_size_tensor for x + input_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + if residual is not None: + return self.weight * x.to(input_dtype), residual + return self.weight * x.to(input_dtype) 
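DeepseekV3RMSNormTorch above and KDeepseekRMSNormIPEXLLM below both boil down to the same RMSNorm computation (variance in float32, rsqrt, cast back, scale by weight); a minimal pure-PyTorch sketch that can be used to sanity-check either path on CPU (function name and shapes here are illustrative):

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # compute the variance in float32 for stability, then cast back to the input dtype
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(orig_dtype)

# illustrative shapes: 2 tokens, hidden size 16
hidden = torch.randn(2, 16, dtype=torch.float16)
weight = torch.ones(16, dtype=torch.float16)
out = rms_norm_ref(hidden, weight)
assert out.shape == hidden.shape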
+ + +class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "xpu", + generate_device: str = "xpu", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.weight.shape[0], + orig_module.variance_epsilon) + self.eps = orig_module.variance_epsilon + + def forward(self, x: torch.Tensor) -> torch.Tensor: + from ipex_llm.transformers.models.common import rms_norm_forward + if x.dtype not in [torch.float32, torch.float16]: + output = rms_norm_forward(self, x.float()) + else: + output = rms_norm_forward(self, x) + return output.to(x.dtype) + + def load(self): + BaseInjectedModule.load(self) + if self.weight.dtype not in [torch.float32, torch.float16]: + self.weight = self.weight.float() \ No newline at end of file diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 293826e..654c9f9 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -14,18 +14,20 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved. import ctypes import torch from torch import Tensor, nn -import KTransformersOps -import vLLMMarlin -from ktransformers.util.custom_gguf import GGUFLoader +if not torch.xpu.is_available(): + import KTransformersOps + import vLLMMarlin +from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader from ktransformers.util.utils import InferenceState -from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import ( - MarlinWorkspace, - marlin_quantize, - GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MIN_THREAD_K, - GPTQ_MARLIN_MAX_PARALLEL, - vllm_marlin_quantize -) +if not torch.xpu.is_available(): + from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import ( + MarlinWorkspace, + marlin_quantize, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MIN_THREAD_K, + GPTQ_MARLIN_MAX_PARALLEL, + vllm_marlin_quantize + ) from ktransformers.operators.base_operator import BaseInjectedModule from transformers.configuration_utils import PretrainedConfig from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant @@ -83,15 +85,15 @@ class KLinearBase(ABC): keys = [self.key] for key in keys: - if self.gguf_loader.safetensor_loader is not None: + if isinstance(self.gguf_loader, SafeTensorLoader): # using safetensor_loader - tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight') - if key+'.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map: - weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv') + tensor = self.gguf_loader.load_tensor(key+'.weight') + if self.gguf_loader.has_tensor(key+'.weight_scale_inv'): + weight_scale_inv = self.gguf_loader.load_tensor(key+'.weight_scale_inv') return nn.Parameter(tensor), nn.Parameter(weight_scale_inv) return nn.Parameter(tensor) - elif key + ".weight" in self.gguf_loader.tensor_file_map: + elif self.gguf_loader.has_tensor(key + ".weight") or "kv_b_proj" in key: if key + ".bias" in self.gguf_loader.tensor_file_map: tensors = self.load_multi(key, ["weight", "bias"], device=device) tensor = tensors["weight"] @@ -99,6 +101,19 @@ class KLinearBase(ABC): # self.qtype = GGML_TYPE_QTYPE_MAP[tensorinfo[key + ".weight"]["ggml_type"]] # print(torch.isinf(tensor).any(), 
torch.isinf(bias).any()) return nn.Parameter(tensor), nn.Parameter(bias) + elif "kv_b_proj" in key and not self.gguf_loader.has_tensor(key + ".weight"): + attn_k_b_tensors = self.load_multi(key.replace("self_attn.kv_b_proj", "attn_k_b"), ["weight"], device=device) + attn_k_b = attn_k_b_tensors["weight"] + del attn_k_b_tensors + attn_k_b = attn_k_b.transpose(1, 2).contiguous() + attn_v_b_tensors = self.load_multi(key.replace("self_attn.kv_b_proj", "attn_v_b"), ["weight"], device=device) + attn_v_b = attn_v_b_tensors["weight"] + del attn_v_b_tensors + kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1) + kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous() + del attn_k_b + del attn_v_b + return nn.Parameter(kv_b_proj) else: tensors = self.load_multi(key, ["weight"], device=device) tensor = tensors["weight"] @@ -502,6 +517,9 @@ class VLinearMarlin(KLinearBase): marlin_s = self.marlin_s.to(x.dtype) sms = -1 + # padding x.shape[0] to avoid CUDA illegal memory access error + x, orig_size_m = self._pad_input(x) + x = vLLMMarlin.gptq_marlin_gemm( x, self.marlin_q_w, @@ -511,26 +529,15 @@ class VLinearMarlin(KLinearBase): self.workspace.scratch, self.num_bits, bsz_tensor, - # torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device), x.shape[0], self.n, x.shape[-1], sms, self.is_k_full, ) - # x = KTransformersOps.gptq_marlin_gemm( - # x, - # self.marlin_q_w, - # marlin_s, - # self.g_idx, - # self.sort_indices, - # self.workspace.scratch, - # self.num_bits, - # x.shape[0], - # self.n, - # x.shape[-1], - # self.is_k_full, - # ) + + x = x[:orig_size_m] + if self.has_bias: x = x + self.bias orig_shape[-1] = self.n @@ -546,6 +553,27 @@ class VLinearMarlin(KLinearBase): self.sort_indices = None self.workspace = None + def _pad_input(self, x): + + size_m = x.shape[0] + size_k = x.shape[1] + + # size_m and align value depends on VLinearMarlin implementation + if size_m > 1024: + align = 1024 + elif size_m > 64: + align = 64 + else: + align = 1 + + padded_size_m = ((size_m + align - 1) // align) * align + + if padded_size_m > size_m: + pad_len = padded_size_m - size_m + pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device) + x = torch.cat([x, pad_tensor], dim = 0).contiguous() + return x, size_m + class KLinearMarlin(KLinearBase): marlin_q_w: torch.Tensor marlin_s: torch.Tensor @@ -760,7 +788,7 @@ class KLinearCPUInfer(KLinearBase): self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device) def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu"): - if self.key + ".weight" in self.gguf_loader.tensor_info: + if self.gguf_loader.has_tensor(self.key + ".weight"): if self.key + ".bias" in self.gguf_loader.tensor_file_map: self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight") self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"] @@ -778,6 +806,75 @@ class KLinearCPUInfer(KLinearBase): if self.has_bias: self.bias = None +class KLinearIPEXLLM(KLinearBase): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + device: str = "xpu", + precision: str = "sym_int4", + **kwargs, + ): + super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + self.has_bias = False + self.dtype = torch.get_default_dtype() + self.weight = None + self.has_bias = False + self.precision = precision + self.qtype = None + + def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = 
None) -> torch.Tensor: + dtype = x.dtype + out_device = x.device + from ipex_llm.transformers.models.common import linear_forward + x = linear_forward(x.half(), self.weight, self.qtype, self.out_features) + + if self.has_bias: + x = x + self.bias + x = x.to(dtype=dtype, device=out_device) + return x + + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if self.loaded: return + if device is None: device = self.device + assert device.lower()[:3] == "xpu", "IPEX-LLM quantized linear only supports XPU device" + if w is None: w = self.load_weight(device=device) + + if isinstance(w, nn.Parameter): + try: + weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T + except: + weight = w.to(dtype=self.dtype).T + self.has_bias = False + elif isinstance(w, tuple): + try: + weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T + except: + weight = w[0].to(dtype=self.dtype).T + self.bias = w[1].to(dtype=self.dtype) + self.has_bias = True + else: + raise ValueError("Invalid weight type") + weight = weight.to("cpu").float().transpose(0, 1).contiguous() + + if self.has_bias: + self.bias = self.bias.to(device) + + # quantize linear weight + from ipex_llm.transformers.models.common import quantize_linear + paramsLowBit, qtype = quantize_linear(weight, self.in_features, self.precision) + self.weight = paramsLowBit.to(device) + self.qtype = qtype + self.loaded = True + + def unload(self): + if self.weight is not None: + self.weight = None + if self.has_bias: + self.bias = None + LINEAR_MAP = { "KLinearMarlin": KLinearMarlin, "KLinearTorch": KLinearTorch, @@ -785,6 +882,7 @@ LINEAR_MAP = { "VLinearMarlin": VLinearMarlin, "KLinearFP8": KLinearFP8, "KLinearQ8": KLinearQ8, + "KLinearIPEXLLM": KLinearIPEXLLM, } class KTransformersLinear(BaseInjectedModule, KLinearBase): diff --git a/ktransformers/operators/mlp.py b/ktransformers/operators/mlp.py index 02648b1..77d7d05 100644 --- a/ktransformers/operators/mlp.py +++ b/ktransformers/operators/mlp.py @@ -1,6 +1,6 @@ from ktransformers.operators.base_operator import BaseInjectedModule -from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.custom_loader import GGUFLoader from transformers import PretrainedConfig import torch.nn as nn from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index bbac29a..e136b57 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -58,7 +58,7 @@ from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.utils import InferenceState, get_compute_capability -from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.custom_loader import GGUFLoader from transformers.configuration_utils import PretrainedConfig from ktransformers.models.modeling_llama import ( LlamaDecoderLayer, @@ -306,6 +306,12 @@ class KQwen2MoeModel(BaseInjectedModule): hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + if torch.xpu.is_available() and inputs_embeds.device.type == "xpu": + position_embeddings = self.rotary_emb(hidden_states, position_ids) + else: + position_embeddings = None + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if 
output_attentions else None @@ -369,6 +375,7 @@ class KQwen2MoeModel(BaseInjectedModule): output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) if per_layer_prefill_flag: # print(f"to cpu") @@ -376,8 +383,10 @@ class KQwen2MoeModel(BaseInjectedModule): torch.cuda.empty_cache() hidden_states = layer_outputs[0] - if use_cache: + if use_cache and len(layer_outputs) > 1: next_decoder_cache = layer_outputs[2 if output_attentions else 1] + else: + next_decoder_cache = None if output_attentions: all_self_attns += (layer_outputs[1],) @@ -396,11 +405,14 @@ class KQwen2MoeModel(BaseInjectedModule): next_cache = None if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - if use_legacy_cache - else next_decoder_cache - ) + if next_decoder_cache is not None: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + else: + next_cache = past_key_values if not return_dict: return tuple( @@ -647,10 +659,20 @@ class KDeepseekV2Model(BaseInjectedModule): if position_ids is None: position_ids = cache_position.unsqueeze(0) + if inputs_embeds.device.type == "xpu" and position_ids is not None: + cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds, + position_ids) + position_embeddings = (cos, sin) + else: + position_embeddings = None + if per_layer_prefill_flag: causal_mask = None else: - if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA: + if (os.name == 'nt' + or get_compute_capability() < 8 + or (self.transfer_map is not None and 'cpu' in self.transfer_map.values()) + or device_manager.gpu_vendor != GPUVendor.NVIDIA): # print("for Windows or GPU before ampere, use forward_windows") # only use mask in forward windows or can't flash attn causal_mask = self._update_causal_mask( @@ -734,6 +756,7 @@ class KDeepseekV2Model(BaseInjectedModule): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) t5 = time.time() if per_layer_prefill_flag: diff --git a/ktransformers/optimize/optimize.py b/ktransformers/optimize/optimize.py index 331e6cf..bbe08c8 100644 --- a/ktransformers/optimize/optimize.py +++ b/ktransformers/optimize/optimize.py @@ -12,7 +12,7 @@ from torch import nn from transformers import AutoConfig from transformers.configuration_utils import PretrainedConfig # from operators import BaseInjectedModule -from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf +from ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory from ktransformers.util.utils import set_module, load_weights import itertools import copy @@ -54,7 +54,7 @@ def del_meta(module:nn.Module): def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str="", default_device: str = "cuda:0"): module_name = prefix[:-1] - translated_name = translate_name_to_gguf(prefix)[:-1] + # translated_name = translate_name_to_gguf(prefix)[:-1] #print("gen_optimize_config", prefix, module_name, translated_name) recursive = True for rule in rule_list: @@ -76,7 +76,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p if "replace" in rule: replace_meta = rule["replace"] if module_name not in out_data: - out_data[module_name]={"key": translated_name, + out_data[module_name]={"key": module_name, "class": replace_meta["class"] if "class" in replace_meta else "default", # "device": 
replace_meta["device"] if "device" in replace_meta else default_device, "kwargs": copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()} @@ -91,7 +91,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p if module_name not in out_data: out_data[module_name]= { "class": "default", - "key": translated_name, + "key": module_name, "kwargs": {"generate_device": default_device, "prefill_device": default_device} } @@ -103,7 +103,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + "." - gen_optimize_config(child, out_data, rule_list, child_prefix) + gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device) def translate_model_config(model_config: PretrainedConfig): @@ -123,12 +123,15 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo model_config = translate_model_config(model_config) - gguf_loader=GGUFLoader(gguf_path) + weights_loader = ModelLoaderFactory.create_loader(gguf_path) with torch.device("meta"): - inject(module, optimize_config, model_config, gguf_loader) + inject(module, optimize_config, model_config, weights_loader) # pre load lm_head because its big inter result - load_weights(module.lm_head, gguf_loader, "lm_head.") - load_weights(module, gguf_loader) - module.gguf_loader = gguf_loader + load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device) + load_weights(module, weights_loader, device=default_device) + module.gguf_loader = weights_loader del_meta(module) - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.xpu.is_available(): + torch.xpu.empty_cache() diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-gpu-cpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-gpu-cpu.yaml new file mode 100644 index 0000000..3425add --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-gpu-cpu.yaml @@ -0,0 +1,184 @@ +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +# === Rotary Embedding Replacement === + +# GPU 0: layers 0–9 +- match: + name: "^model\\.layers\\.(0|[1-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +# CPU: layers 10-29 +- match: + name: "^model\\.layers\\.([12][0-9])\\." 
+ class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +# === Linear Layers Replacement (excluding self_attn) === + +# GPU 0: layers 0–9 +- match: + name: "^model\\.layers\\.(0|[1-9])\\.(?!self_attn).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" +# CPU: layers 10-29 +- match: + name: "^model\\.layers\\.([12][0-9])\\.(?!self_attn).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + generate_op: "KLinearCPUInfer" + prefill_op: "KLinearTorch" + out_device: "cpu" + +# === MLP (MoE) Replacement === + +# GPU 0: layers 0–9 +- match: + name: "^model\\.layers\\.(0|[1-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +# CPU: layers 10-29 +- match: + name: "^model\\.layers\\.([12][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +# === MLP Gate Replacement === + +# GPU 0: layers 0–9 +- match: + name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.gate$" + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +# CPU: layers 10-29 +- match: + name: "^model\\.layers\\.([12][0-9])\\.mlp\\.gate$" + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +# === MLP Experts Replacement === + +# GPU 0: layers 0–9 +- match: + name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:0" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda:0" + recursive: False # don't recursively inject submodules of this module +# CPU: layers 10-29 +- match: + name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cpu" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cpu" + recursive: False # don't recursively inject submodules of this module + +# === Self-Attention Replacement === + +# GPU 0: layers 0–9 +- match: + name: "^model\\.layers\\.(0|[1-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +# CPU: layers 
10-29 +- match: + name: "^model\\.layers\\.([12][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +# === Overall Model Replacement with Transfer Map === + +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + transfer_map: + 10: "cpu" + +# === Default Catch-All for Other Modules ===# +# GPU 0: layers 0–9 +- match: + name: "^model\\.layers\\.(0|[1-9])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + +#lmm_head on GPU 0 +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +# CPU: layers 10-29 +- match: + name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml new file mode 100644 index 0000000..670f6d5 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml @@ -0,0 +1,91 @@ +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearFP8" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + backend: "llamafile" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + 
generate_device: "cpu" + prefill_device: "cpu" + +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm + replace: + class: ktransformers.operators.layernorm.RMSNorm + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP + replace: + class: ktransformers.operators.mlp.kDeepseekV3MLP + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml new file mode 100644 index 0000000..5de582f --- /dev/null +++ b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml @@ -0,0 +1,64 @@ +- match: + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^model\\.layers\\..*" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "xpu" + prefill_device: "xpu" + generate_op: "KLinearIPEXLLM" + prefill_op: "KLinearIPEXLLM" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "xpu" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "xpu" + recursive: False # don't recursively inject submodules of this module +- match: + class: ktransformers.models.modeling_deepseek.DeepseekV2RMSNorm + replace: + class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + device: "xpu" +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml new file mode 100644 index 0000000..c0e46c3 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml @@ -0,0 +1,81 @@ +- match: + class: 
ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "xpu" + prefill_device: "xpu" + generate_op: "KLinearIPEXLLM" + prefill_op: "KLinearIPEXLLM" +- match: + name: "^model\\.layers\\..*" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "xpu" + prefill_device: "xpu" + generate_op: "KLinearIPEXLLM" + prefill_op: "KLinearIPEXLLM" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm + replace: + class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGateIPEXLLM + kwargs: + generate_device: "xpu:0" + prefill_device: "xpu:0" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "xpu" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "xpu" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "xpu" + prefill_device: "xpu" + absorb_for_prefill: False # change this to True to enable long context(prefill may slower). 
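These xpu rule files are consumed the same way as the existing CUDA ones; a rough usage sketch, assuming the optimize_and_load_gguf(module, rule_file, gguf_path, model_config) entry point from ktransformers/optimize/optimize.py and purely illustrative local paths:

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from ktransformers.optimize.optimize import optimize_and_load_gguf

# illustrative paths; point these at your own checkpoint and GGUF directory
model_dir = "/models/DeepSeek-V3"
gguf_dir = "/models/DeepSeek-V3-GGUF"
rule_file = "ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml"

config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
with torch.device("meta"):
    # build the skeleton on the meta device; the injection rules then place the real weights
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
optimize_and_load_gguf(model, rule_file, gguf_dir, config)
model.eval()

With these rules the dense pieces (linears, norms, attention) land on the Arc GPU via KLinearIPEXLLM and the IPEX-LLM norm and gate operators, while the routed experts stay on the CPU backend through KExpertsCPU with out_device "xpu".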
+- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml new file mode 100644 index 0000000..6bb4dae --- /dev/null +++ b/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml @@ -0,0 +1,80 @@ +- match: + name: "rotary_emb$" + replace: + class: ktransformers.operators.RoPE.KQwen3MoeRotaryEmbedding + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "xpu" + prefill_device: "xpu" + generate_op: "KLinearIPEXLLM" + prefill_op: "KLinearIPEXLLM" +- match: + name: "^model\\.layers\\.(?!.*mlp\\.gate).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "xpu" + prefill_device: "xpu" + generate_op: "KLinearIPEXLLM" + prefill_op: "KLinearIPEXLLM" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock + replace: + class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2 # mlp module with custom forward function + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "xpu" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "xpu" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KQwen3MoeAttentionIPEXLLM + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KQwen2MoeModel" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" +- match: + class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm + replace: + class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP + replace: + class: ktransformers.operators.mlp.KQwen2MoeMLP + kwargs: + generate_device: "xpu" + prefill_device: "xpu" diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index 1210e14..748bd47 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -128,10 +128,7 @@ class ArgumentParser: else: args.model_dir = self.cfg.model_dir args.model_path = self.cfg.model_path - # set config from args - for key, value in vars(args).items(): - if value is not None 
and hasattr(self.cfg, key): - setattr(self.cfg, key, value) + # we add the name not match args individually self.cfg.model_device = args.device self.cfg.mount_web = args.web @@ -140,10 +137,15 @@ class ArgumentParser: self.cfg.user_force_think = args.force_think model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) - if args.architectures == "Qwen3MoeForCausalLM" or args.architectures == "Qwen2MoeForCausalLM" : + if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" : args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim + args.architectures = model_config.architectures[0] else: args.gpu_memory_size = args.cache_lens*2*576*61 + # set config from args + for key, value in vars(args).items(): + if value is not None and hasattr(self.cfg, key): + setattr(self.cfg, key, value) self.cfg.gpu_memory_size = args.gpu_memory_size free_ports = get_free_ports(3, [args.port]) args.sched_port = free_ports[0] diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py index 2d89332..a385c9c 100644 --- a/ktransformers/server/backend/interfaces/balance_serve.py +++ b/ktransformers/server/backend/interfaces/balance_serve.py @@ -195,13 +195,13 @@ class Engine: self.block_num = inference_context.k_cache[0].size(1) + self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num) #@TODO add config if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM": - self.model.init_wrapper(self.args.use_cuda_graph, self.device, 1024 ,args.max_batch_size, self.block_num) # TODO: 1024 is a magic number(max_batch_tokens) + self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num) else: self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num) - self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num) self.sampler = Sampler() self.query_manager = QueryManager(device = self.device, page_size = args.page_size) diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 6bde540..78cb73f 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -11,6 +11,14 @@ from transformers import ( StaticCache, AutoModelForCausalLM, BitsAndBytesConfig, + LogitsProcessorList, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + MinPLogitsWarper, + TypicalLogitsWarper, + EpsilonLogitsWarper, + EtaLogitsWarper, ) from ktransformers.server.config.config import Config @@ -206,6 +214,58 @@ class TransformersInterface(BackendInterfaceBase): self.seq_length += 1 return self.streamer.put(new_tokens) + @staticmethod + def tf_logits_warper(generation_config): + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances + used for multinomial sampling. + """ + + # instantiate warpers list + warpers = LogitsProcessorList() + + # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a + # better score (i.e. 
keep len(list(generation_config._eos_token_tensor)) + 1) + if generation_config.num_beams > 1: + if isinstance(generation_config._eos_token_tensor, list): + min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 + elif isinstance(generation_config._eos_token_tensor, torch.Tensor): + min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 + else: + min_tokens_to_keep = 2 + else: + min_tokens_to_keep = 1 + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if generation_config.temperature is not None and generation_config.temperature != 1.0: + warpers.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.min_p is not None: + # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) + warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.typical_p is not None and generation_config.typical_p < 1.0: + warpers.append( + TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: + warpers.append( + EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: + warpers.append( + EtaLogitsWarper( + epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device + ) + ) + # `LogitNormalization` should always be the last logit processor, when present + if generation_config.renormalize_logits is True: + warpers.append(LogitNormalization()) + return warpers + def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None): if temperature is None or temperature == 0: temperature = self.model.generation_config.temperature @@ -222,14 +282,8 @@ class TransformersInterface(BackendInterfaceBase): repetition_penalty=self.args.repetition_penalty # change this to modify generate config ) self.inputs = inputs - try: # transformers==4.43 - self.logits_warper = ( - self.model._get_logits_warper(generation_config, device=device) - ) - except: - self.logits_warper = ( - self.model._get_logits_warper(generation_config) - ) + + self.logits_warper = self.tf_logits_warper(generation_config) def logits_to_token(self, logits: torch.Tensor): logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1)) diff --git a/ktransformers/server/balance_serve/inference/forward_batch.py b/ktransformers/server/balance_serve/inference/forward_batch.py index 7022d9e..26b4d3d 100644 --- a/ktransformers/server/balance_serve/inference/forward_batch.py +++ b/ktransformers/server/balance_serve/inference/forward_batch.py @@ -200,7 +200,7 @@ class ForwardBatchInput: device=None, tokens: torch.Tensor = None, num_mini_batches: int = 1, - max_seq_length: int = 1024, # TODO: add to yaml + max_seq_length: int = 4096, # TODO: 
add to yaml prefill_query_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, # TODO: use config prefill_active_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, gen_prefill: bool = True, @@ -223,12 +223,12 @@ class ForwardBatchInput: decode_querys_info = [] for i in range(min(decode_batch_size, cuda_lens)): - query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, max_seq_length, page_size, device, is_prefill=False, offset=offset) + query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, 256, page_size, device, is_prefill=False, offset=offset) offset += max_seq_length // page_size if tokens is not None: query_info.query_tokens[prefill_active_length:prefill_active_length + 1].copy_(tokens) if decode_active_position is None: - query_info.active_position = prefill_active_length + query_info.active_position = 255 else: query_info.active_position = decode_active_position[i] diff --git a/ktransformers/server/balance_serve/inference/model_runner.py b/ktransformers/server/balance_serve/inference/model_runner.py index 79b3053..55dfb6d 100644 --- a/ktransformers/server/balance_serve/inference/model_runner.py +++ b/ktransformers/server/balance_serve/inference/model_runner.py @@ -39,6 +39,17 @@ def pad_num_tokens(num_tokens): def deduplicate_and_sort(lst): return sorted(set(lst)) +def generate_cuda_graphs(chunk_size: int) -> list: + # 如果输入不符合要求,assert掉 + assert chunk_size <= 1024 or chunk_size % 1024 == 0, "chunk_size must <= 1024 or a multiple of 1024" + base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size] + + if chunk_size <= 1024: + return deduplicate_and_sort(base_list) + + multiples = [i for i in range(1024, chunk_size + 1, 1024)] + + return deduplicate_and_sort(base_list + multiples) class ModelRunner: """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile.""" @@ -56,7 +67,7 @@ class ModelRunner: self.features_buf = None self.output = None self.graph_memory_pool = None - self.cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size]) + self.cuda_graphs = generate_cuda_graphs(Config().chunk_size) self.use_cuda_graph = use_cuda_graph self.model_time = 0 self.page_size = page_size diff --git a/ktransformers/tests/dequant_gpu.py b/ktransformers/tests/dequant_gpu.py index 0dd5272..3dbd794 100644 --- a/ktransformers/tests/dequant_gpu.py +++ b/ktransformers/tests/dequant_gpu.py @@ -7,7 +7,7 @@ sys.path.append(current_path+"/../..") import numpy as np # from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin # from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch -from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.custom_loader import GGUFLoader import torch import KTransformersOps torch.set_default_dtype(torch.bfloat16) diff --git a/ktransformers/tests/dequant_gpu_t.py b/ktransformers/tests/dequant_gpu_t.py index 4b2556d..06de4a0 100644 --- a/ktransformers/tests/dequant_gpu_t.py +++ b/ktransformers/tests/dequant_gpu_t.py @@ -9,7 +9,7 @@ from pycuda.compiler import SourceModule import numpy as np from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch -from ktransformers.util.custom_gguf import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k +from ktransformers.util.custom_loader 
import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k import torch import KTransformersOps torch.set_default_dtype(torch.bfloat16) diff --git a/ktransformers/tests/test_speed.py b/ktransformers/tests/test_speed.py index b45bf87..6f435b4 100644 --- a/ktransformers/tests/test_speed.py +++ b/ktransformers/tests/test_speed.py @@ -159,5 +159,7 @@ if __name__ == "__main__": prompt = ktansformer_prompt1024 elif args.prompt_lens == 2048: prompt = ktansformer_prompt1024 * 2 + elif args.prompt_lens == 4096: + prompt = ktansformer_prompt1024 * 4 asyncio.run(main(args.concurrent, prompt, max_tokens, model)) diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index b3d98d3..5e4ffd6 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -24,8 +24,8 @@ from typing import Sequence import os from enum import IntEnum import torch -import KTransformersOps -from .custom_loader import SafeTensorLoader +if not torch.xpu.is_available(): + import KTransformersOps import ctypes import math @@ -166,238 +166,6 @@ DATA_TYPES = { "FP8": 13, } -class GGUFLoader: - tensor_info: dict - gguf_path: str - tensor_file_map: dict # {tensor_name: tensor_file_path} - gguf_file_meta: dict - safetensor_loader: SafeTensorLoader - def __init__(self, gguf_path: str): - # Check dir exist - if not os.path.exists(gguf_path): - raise FileNotFoundError(f"GGUF dir not found: {gguf_path}") - if os.path.isfile(gguf_path): - gguf_path = os.path.dirname(gguf_path) - - self.safetensor_loader = None - - self.tensor_info = {} - self.gguf_path = gguf_path - self.tensor_file_map = {} - self.file_data_map = {} - self.gguf_file_meta = {} - self.tensor_device_map = {} - - # I know this is ugly, but I don't want to change the original code too much - # TODO: merge gguf load and other loads. 
- safetensor_loader = SafeTensorLoader(gguf_path) - if safetensor_loader.tensor_file_map: - self.safetensor_loader = safetensor_loader - return - # Walk through all the .gguf files in the directory - found_gguf = False - for root, dirs, files in os.walk(gguf_path): - for file in files: - if file.endswith(".gguf"): - found_gguf = True - file_name = os.path.join(root, file) - with open(file_name, "rb") as f: - self.load_gguf(f) - if file_name not in self.file_data_map: - self.file_data_map[file_name] = np.memmap(file_name, mode = 'r') - if not found_gguf: - raise FileNotFoundError(f"Cannot find any .gguf files in: {gguf_path}") - - def load_gguf(self, f): - f.seek(0) - assert f.read(4) == b'GGUF' - values = struct.unpack("torch.Tensor: - t = self.tensor_info[name] - if device.lower() == "cpu": - print(f"loading expert {expert_id} of {name} with CPU") - shape = t["shape"] - ggml_type = t["ggml_type"] - if ggml_type not in GGML_NAMES: - raise NotImplementedError(f"ggml_type {ggml_type} not implemented") - ggml_name = GGML_NAMES[ggml_type] - - # TODO: experts may fused in quant block, split it - assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant" - - blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name] - block_size = GGML_BLOCK_SIZES[ggml_name] - offset = expert_id * block_size * blocks_per_experts - data = data[offset: offset + block_size * blocks_per_experts] - - if "cuda" in device.lower(): - values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype) - else: - values = GGML_DEQUANTIZE[ggml_name](data) - values = torch.from_numpy(values.copy()) - - if ggml_name == "BF16": - values = values.view(torch.bfloat16) - values = values.view(shape[-2::-1]) - - return values - - def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor: - t = self.tensor_info[name] - if device.lower() == "cpu": - print(f"loading {name} with CPU") - if target_dtype == None: - target_dtype = torch.get_default_dtype() - - shape = t["shape"] - ggml_type = t["ggml_type"] - - if ggml_type not in GGML_NAMES: - raise NotImplementedError(f"ggml_type {ggml_type} not implemented") - - ggml_name = GGML_NAMES[ggml_type] - - data = self.get_mmap_tensor(name) - - block_size = GGML_BLOCK_SIZES[ggml_name] - elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name] - num_elements = int(np.prod(shape)) - num_blocks = num_elements // elements_per_block - - blocks_per_iter = 16384 - if num_blocks > blocks_per_iter: # dequant large tensor - values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device) - for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter): - blocks_begin = i * blocks_per_iter - blocks_end = min(blocks_begin + blocks_per_iter, num_blocks) - if "cuda" in device.lower(): - cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype) - else: - cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) - cur_values = torch.from_numpy(cur_values.copy()) - - cur_values = cur_values.view(-1, elements_per_block) - if ggml_name == "BF16": - cur_values = cur_values.view(torch.bfloat16) - values[blocks_begin : blocks_end] = cur_values - else: - if "cuda" in device.lower(): - values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) - else: - values = GGML_DEQUANTIZE[ggml_name](data) - values = torch.from_numpy(values) - - if ggml_name == "BF16": - values = 
values.view(torch.bfloat16) - - - values = values.view(shape[::-1]) - if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: - n_head = self.gguf_file_meta['llama.attention.head_count'] - values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:]) - .swapaxes(1, 2) - .reshape(values.shape)) - elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: - n_head = self.gguf_file_meta['llama.attention.head_count_kv'] - values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:]) - .swapaxes(1, 2) - .reshape(values.shape)) - return values def read_value(f, data_type): if data_type == DATA_TYPES["string"]: @@ -921,6 +689,7 @@ def translate_name_to_gguf(name): name = name.replace(".gate_up_proj.", ".up_proj") name = name.replace(".mlp.shared_experts.down_proj", ".ffn_down_shexp") + name = name.replace(".mlp.gate.e_score_correction_bias", ".exp_probs_b.bias") name = name.replace(".mlp.gate", ".ffn_gate_inp") name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp") name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp") diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py index ecc09a0..003f93c 100644 --- a/ktransformers/util/custom_loader.py +++ b/ktransformers/util/custom_loader.py @@ -7,15 +7,39 @@ from typing import Sequence import os from enum import IntEnum import torch -import KTransformersOps +if not torch.xpu.is_available(): + import KTransformersOps from safetensors import safe_open from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant +from ktransformers.util.custom_gguf import * from safetensors.torch import save_file +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, Union -class SafeTensorLoader: +class ModelLoader(ABC): + """ + Abstract base class for model loaders. + Defines the interface that all model loaders must implement. + """ tensor_file_map = {} - tensor_type_map = {} - file_handle_map = {} + @abstractmethod + def has_tensor(cls, name: str): + """ + Check if the tensor exists in the loader. 
+ + Args: + name: Name of the tensor to check + + Returns: + bool: True if the tensor exists, False otherwise + """ + pass + +class SafeTensorLoader(ModelLoader): + tensor_file_map: dict + tensor_type_map: dict + file_handle_map: dict + tensor_device_map: dict def __init__(self, file_path: str): self.__load_tensor_file_map(file_path) @@ -28,6 +52,10 @@ class SafeTensorLoader: folder_path = os.path.dirname(file_path) else: folder_path = file_path + self.file_handle_map = {} + self.tensor_file_map = {} + self.tensor_type_map = {} + self.tensor_device_map = {} found_safetensor = False for root, _, files in os.walk(folder_path): @@ -57,7 +85,11 @@ class SafeTensorLoader: # raise FileNotFoundError(f"No Safetensor files found in {folder_path}") def load_tensor(self, key: str, device: str="cpu"): - if key not in self.tensor_file_map: + if translate_name_to_gguf(key) in self.tensor_file_map: + key = translate_name_to_gguf(key) + elif key in self.tensor_file_map: + pass + else: raise KeyError(f"Key {key} not found in Safetensor files") file = self.tensor_file_map[key] f = self.file_handle_map.get(file) @@ -66,13 +98,145 @@ class SafeTensorLoader: tensor = f.get_tensor(key) return tensor.to(device) + def load_experts(self, key: str, device: str="cpu"): + ''' + Load experts from safetensor + key: the name of the experts + device: the device to load the experts to + return: dict, + {up: tensor, down: tensor, gate: tensor, up_type: int, down_type: int, gate_type: int} + {xxx}_type: the type of the up tensor, corresponding to the ggml type + ''' + if self.has_tensor(translate_name_to_gguf(key)+".ffn_gate_exps.weight"): + # legacy branch for loading hybrid model + base_key = translate_name_to_gguf(key) + # Load experts from safetensor + gate_key = f"{base_key}.ffn_gate_exps.weight" + gate_type_key = f"{base_key}.ffn_gate_exps.ggml_type" + up_key = f"{base_key}.ffn_up_exps.weight" + up_type_key = f"{base_key}.ffn_up_exps.ggml_type" + down_key = f"{base_key}.ffn_down_exps.weight" + down_type_key = f"{base_key}.ffn_down_exps.ggml_type" + gate_tensor = self.load_tensor(gate_key, device).numpy() + up_tensor = self.load_tensor(up_key, device).numpy() + down_tensor = self.load_tensor(down_key, device).numpy() + gate_type = self.load_tensor(gate_type_key, device).item() + up_type = self.load_tensor(up_type_key, device).item() + down_type = self.load_tensor(down_type_key, device).item() + + return { + "up": up_tensor, + "gate": gate_tensor, + "down": down_tensor, + "up_type": up_type, + "gate_type": gate_type, + "down_type": down_type + } + + else: + # Load experts from safetensor + base_key = key # e.g. 
"model.layers.3.mlp.experts" + experts_count = 0 + + # First, count how many experts we have by checking for expert 0's up_proj + while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight"): + experts_count += 1 + + if experts_count == 0: + raise ValueError(f"No experts found for key {base_key}") + + # Initialize empty lists to store tensors for each projection type + up_projs = [] + gate_projs = [] + down_projs = [] + + # Load all expert weights + for expert_id in range(experts_count): + up_key = f"{base_key}.{expert_id}.up_proj.weight" + gate_key = f"{base_key}.{expert_id}.gate_proj.weight" + down_key = f"{base_key}.{expert_id}.down_proj.weight" + + up_tensor = self.load_tensor(up_key, device) + gate_tensor = self.load_tensor(gate_key, device) + down_tensor = self.load_tensor(down_key, device) + + up_projs.append(up_tensor) + gate_projs.append(gate_tensor) + down_projs.append(down_tensor) + + # Stack the tensors along a new dimension + up_tensor = torch.stack(up_projs, dim=0) + gate_tensor = torch.stack(gate_projs, dim=0) + down_tensor = torch.stack(down_projs, dim=0) + + # Get original dtype for GGML type determination + orig_up_dtype = up_tensor.dtype + orig_gate_dtype = gate_tensor.dtype + orig_down_dtype = down_tensor.dtype + + # Convert to numpy with proper bfloat16 support + up_numpy = up_tensor.view(torch.uint16).numpy() + gate_numpy = gate_tensor.view(torch.uint16).numpy() + down_numpy = down_tensor.view(torch.uint16).numpy() + + # Determine tensor data types for GGML conversion + def get_ggml_type(dtype): + if dtype == torch.float32: + return GGMLQuantizationType.F32 + elif dtype == torch.float16: + return GGMLQuantizationType.F16 + elif dtype == torch.bfloat16: + return GGMLQuantizationType.BF16 + else: + raise ValueError(f"Unsupported tensor dtype: {dtype}") + + return { + "up": up_numpy, + "gate": gate_numpy, + "down": down_numpy, + "up_type": get_ggml_type(orig_up_dtype), + "gate_type": get_ggml_type(orig_gate_dtype), + "down_type": get_ggml_type(orig_down_dtype) + } + + def load_gate(self, key: str, device: str="cpu"): + ''' + Load gate from safetensor + key: the name of the gate + device: the device to load the gate to + return: dict, + {'weight': tensor, 'e_score_correction_bias': tensor} + ''' + target = ["weight", "e_score_correction_bias"] + res = {'weight': None, 'e_score_correction_bias': None} + if self.has_tensor(translate_name_to_gguf(key)+".ffn_gate_exps.weight"): + # legacy branch for loading hybrid model + base_key = key + for k in target: + translated_key = translate_name_to_gguf(f"{base_key}.{k}") + if self.has_tensor(translated_key): + tensor = self.load_tensor(translated_key, device) + res[k] = tensor + else: + # Load gate from safetensor + base_key = key + for k in target: + if self.has_tensor(f"{base_key}.{k}"): + tensor = self.load_tensor(f"{base_key}.{k}", device) + res[k] = tensor + return res + def close_all_handles(self): for handle in self.file_handle_map.values(): handle.close() self.file_handle_map.clear() def load_dequantized_tensor(self, key:str, device: str="cpu"): - if key not in self.tensor_file_map: + if key in self.tensor_file_map and translate_name_to_gguf(key): + pass + elif translate_name_to_gguf(key) in self.tensor_file_map: + key = translate_name_to_gguf(key) + else: raise KeyError(f"Key {key} not found in Safetensor files") file = self.tensor_file_map[key] f = self.file_handle_map.get(file) @@ -83,4 +247,320 @@ class SafeTensorLoader: if key[:-7] + ".weight_scale_inv" in self.tensor_file_map: weight_scale_inv = 
f.get_tensor(key[:-7] + ".weight_scale_inv").to(device) tensor = weight_dequant(tensor, weight_scale_inv) - return tensor.to(device) \ No newline at end of file + return tensor.to(device) + + def has_tensor(self, name: str): + return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map + +class GGUFLoader(ModelLoader): + tensor_info: dict + gguf_path: str + tensor_file_map: dict # {tensor_name: tensor_file_path} + gguf_file_meta: dict + safetensor_loader: SafeTensorLoader + def __init__(self, gguf_path: str): + # Check dir exist + if not os.path.exists(gguf_path): + raise FileNotFoundError(f"GGUF dir not found: {gguf_path}") + if os.path.isfile(gguf_path): + gguf_path = os.path.dirname(gguf_path) + + self.safetensor_loader = None + + self.tensor_info = {} + self.gguf_path = gguf_path + self.tensor_file_map = {} + self.file_data_map = {} + self.gguf_file_meta = {} + self.tensor_device_map = {} + + # Walk through all the .gguf files in the directory + found_gguf = False + for root, dirs, files in os.walk(gguf_path): + for file in files: + if file.endswith(".gguf"): + found_gguf = True + file_name = os.path.join(root, file) + with open(file_name, "rb") as f: + self.load_gguf(f) + if file_name not in self.file_data_map: + self.file_data_map[file_name] = np.memmap(file_name, mode = 'r') + if not found_gguf: + raise FileNotFoundError(f"Cannot find any .gguf files in: {gguf_path}") + + def load_gguf(self, f): + f.seek(0) + assert f.read(4) == b'GGUF' + values = struct.unpack("torch.Tensor: + name = translate_name_to_gguf(name) + t = self.tensor_info[name] + shape = t["shape"] + ggml_type = t["ggml_type"] + if ggml_type not in GGML_NAMES: + raise NotImplementedError(f"ggml_type {ggml_type} not implemented") + ggml_name = GGML_NAMES[ggml_type] + + # TODO: experts may fused in quant block, split it + assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant" + + blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name] + block_size = GGML_BLOCK_SIZES[ggml_name] + offset = expert_id * block_size * blocks_per_experts + data = data[offset: offset + block_size * blocks_per_experts] + + if "cuda" in device.lower(): + values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype) + else: + values = GGML_DEQUANTIZE[ggml_name](data) + values = torch.from_numpy(values.copy()) + + if ggml_name == "BF16": + values = values.view(torch.bfloat16) + values = values.view(shape[-2::-1]) + + return values + + def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor: + name = translate_name_to_gguf(name) + t = self.tensor_info[name] + if target_dtype == None: + target_dtype = torch.get_default_dtype() + + shape = t["shape"] + ggml_type = t["ggml_type"] + + if ggml_type not in GGML_NAMES: + raise NotImplementedError(f"ggml_type {ggml_type} not implemented") + + ggml_name = GGML_NAMES[ggml_type] + + data = self.get_mmap_tensor(name) + + block_size = GGML_BLOCK_SIZES[ggml_name] + elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name] + num_elements = int(np.prod(shape)) + num_blocks = num_elements // elements_per_block + + blocks_per_iter = 16384 + if num_blocks > blocks_per_iter: # dequant large tensor + values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device) + for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter): + blocks_begin = i * blocks_per_iter + blocks_end = min(blocks_begin + blocks_per_iter, 
num_blocks) + if "cuda" in device.lower(): + try: + cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype) + except: + cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) + cur_values = torch.from_numpy(cur_values.copy()).to(device) + else: + cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) + cur_values = torch.from_numpy(cur_values.copy()) + + cur_values = cur_values.view(-1, elements_per_block) + if ggml_name == "BF16": + cur_values = cur_values.view(torch.bfloat16) + values[blocks_begin : blocks_end] = cur_values + else: + if "cuda" in device.lower(): + values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) + else: + np_values = np.copy(GGML_DEQUANTIZE[ggml_name](data)) + values = torch.from_numpy(np_values).to(device) + del np_values + + if ggml_name == "BF16": + values = values.view(torch.bfloat16) + + + values = values.view(shape[::-1]) + if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: + n_head = self.gguf_file_meta['llama.attention.head_count'] + values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:]) + .swapaxes(1, 2) + .reshape(values.shape)) + elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: + n_head = self.gguf_file_meta['llama.attention.head_count_kv'] + values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:]) + .swapaxes(1, 2) + .reshape(values.shape)) + return values + def has_tensor(self, name: str): + name = translate_name_to_gguf(name) + return name in self.tensor_info + + def get_ggml_type(self, name: str): + name = translate_name_to_gguf(name) + if name not in self.tensor_info: + raise KeyError(f"Key {name} not found in GGUF files") + return self.tensor_info[name]["ggml_type"] + +class ModelLoaderFactory: + """ + Factory class for creating model loaders. + Automatically detects the model format based on file extensions in the directory. + """ + + @staticmethod + def create_loader(path: str): + """ + Create a model loader for the given path by detecting the model format. + The function checks for the presence of .safetensors or .gguf files + in the specified path and creates the appropriate loader. 
+ + Args: + path: Path to the model directory or file + + Returns: + An appropriate ModelLoader instance (SafeTensorLoader or GGUFLoader) + + Raises: + FileNotFoundError: If no supported model files are found in the path + """ + if not os.path.exists(path): + raise FileNotFoundError(f"Path not found: {path}") + + # Normalize to directory path if a file was provided + if os.path.isfile(path): + if path.endswith(".safetensors"): + return SafeTensorLoader(path) + elif path.endswith(".gguf"): + return GGUFLoader(path) + else: + folder_path = os.path.dirname(path) + else: + folder_path = path + + # Check for safetensors files + has_safetensors = False + has_gguf = False + + for root, _, files in os.walk(folder_path): + for file in files: + if file.endswith(".safetensors"): + has_safetensors = True + break + elif file.endswith(".gguf"): + has_gguf = True + break + if has_safetensors or has_gguf: + break + + # Create the appropriate loader based on detected file types + # Prioritize SafeTensor over GGUF if both are present + if has_safetensors: + try: + return SafeTensorLoader(folder_path) + except Exception as e: + print(f"Failed to create SafeTensorLoader: {e}") + # Fall through to try GGUF if SafeTensor fails + if not has_gguf: + raise + + if has_gguf: + try: + return GGUFLoader(folder_path) + except Exception as e: + print(f"Failed to create GGUFLoader: {e}") + raise + + # No supported model files found + raise FileNotFoundError(f"No .safetensors or .gguf files found in: {folder_path}") \ No newline at end of file diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 30f8880..98a44f2 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -11,13 +11,24 @@ from torch import nn import itertools import time import enum -from ktransformers.util.custom_gguf import translate_name_to_gguf -from ktransformers.util.custom_gguf import GGUFLoader +from transformers import ( + LogitsProcessorList, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + MinPLogitsWarper, + TypicalLogitsWarper, + EpsilonLogitsWarper, + EtaLogitsWarper, +) + +from ktransformers.util.custom_loader import ModelLoaderFactory, ModelLoader, SafeTensorLoader, GGUFLoader, translate_name_to_gguf from ktransformers.operators import base_operator from ktransformers.models.custom_cache import StaticCache from ktransformers.util.cuda_graph_runner import CUDAGraphRunner from ktransformers.util.textstream import TextStreamer -from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton +if not torch.xpu.is_available(): + from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton import socket warm_uped = False @@ -49,6 +60,8 @@ def get_compute_capability(device:torch.device = None): return min_compute_capability_major else: return torch.cuda.get_device_properties(device) + else: + return 0 def set_module(model, submodule_key, module): tokens = submodule_key.split('.') @@ -87,45 +100,132 @@ def get_all_used_cuda_device(device_map:dict): all_device_list = list(all_device_list) return all_device_list -def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""): +def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"): prefix = prefix.replace("orig_module.", "") persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set} local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items()) local_state = {k: v for 
k, v in local_name_params if v is not None} for name, param in local_state.items(): key = prefix + name - translated_key = translate_name_to_gguf(key) + translated_key = key # TODO: Merge all loader. # I know this is ugly but lets do it for now. - if gguf_loader.safetensor_loader is not None: - load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor - tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map + if isinstance(gguf_loader, SafeTensorLoader): + load_dequantized_tensor = gguf_loader.load_dequantized_tensor else: load_dequantized_tensor = gguf_loader.load_gguf_tensor tensor_file_map = gguf_loader.tensor_file_map - if translated_key in tensor_file_map: + if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key: target_dtype = torch.get_default_dtype() device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) print(f"loading {translated_key} to {device}") - torch.cuda.empty_cache() - weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype) - set_param(module, name, weights) - del weights + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.xpu.is_available(): + torch.xpu.empty_cache() + if "kv_b_proj" in translated_key and not gguf_loader.has_tensor(translated_key): + attn_k_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_k_b"), device=device).to(dtype=target_dtype) + attn_k_b = attn_k_b.transpose(1, 2).contiguous() + attn_v_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_v_b"), device=device).to(dtype=target_dtype) + kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1) + kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous() + set_param(module, name, kv_b_proj) + del attn_k_b + del attn_v_b + else: + weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype) + set_param(module, name, weights) + del weights else: #print(load_config.tensor_file_map.keys()) raise Exception(f"can't find {translated_key} in GGUF file!") -def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''): + +def sync_all_device(all_device_list): + for device in all_device_list: + if "cuda" in device.lower(): + torch.cuda.synchronize(device) + elif "xpu" in device.lower(): + torch.xpu.synchronize(device) + else: + raise RuntimeError("The device {} is not available".format(device)) + +torch_device_mapping = {"cuda": "cuda:0", "xpu": "xpu:0"} + +def xpu_fp16_model(config): + # This function is to check if we run this model on XPU with FP16 dtype + if not torch.xpu.is_available(): + return False + if config.architectures[0] == "DeepseekV3ForCausalLM": + return True + if config.architectures[0] == "Qwen3MoeForCausalLM" and config.hidden_size == 4096: + # Qwen3-30B seems to have precision issues with FP16 + # so we only use FP16 for Qwen3-235B now + return True + return False + +def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"): #print(f"recursively loading weights {prefix}") if not isinstance(module, base_operator.BaseInjectedModule): - load_cur_state_dict(module, gguf_loader, prefix, device=device) + load_cur_state_dict(module, gguf_loader, prefix, device=device) for name, child in module._modules.items(): - load_weights(child, gguf_loader, prefix+name+".") + load_weights(child, gguf_loader, prefix+name+".", device=device) else: module.load() +def tf_logits_warper(generation_config): + """ + This function returns a 
[`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances + used for multinomial sampling. + """ + + # instantiate warpers list + warpers = LogitsProcessorList() + + # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a + # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1) + if generation_config.num_beams > 1: + if isinstance(generation_config._eos_token_tensor, list): + min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 + elif isinstance(generation_config._eos_token_tensor, torch.Tensor): + min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 + else: + min_tokens_to_keep = 2 + else: + min_tokens_to_keep = 1 + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if generation_config.temperature is not None and generation_config.temperature != 1.0: + warpers.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.min_p is not None: + # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) + warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.typical_p is not None and generation_config.typical_p < 1.0: + warpers.append( + TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: + warpers.append( + EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: + warpers.append( + EtaLogitsWarper( + epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device + ) + ) + # `LogitNormalization` should always be the last logit processor, when present + if generation_config.renormalize_logits is True: + warpers.append(LogitNormalization()) + return warpers + def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False, num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None): @@ -134,8 +234,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud torch._dynamo.config.suppress_errors = True batch_size, seq_length = inputs.shape device_map = model.gguf_loader.tensor_device_map - torch_device = get_device('blk.0.self_attn', device_map) - torch_device = "cuda:0" if torch_device == "cuda" else torch_device + torch_device = get_device('model.layers.0.self_attn', device_map) + torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device inputs = inputs.to(torch_device) all_cuda_device = get_all_used_cuda_device(device_map) @@ -148,7 +248,12 @@ def prefill_and_generate(model, tokenizer, inputs, 
max_new_tokens=10000, use_cud logits = cuda_graph_runner(cur_token, position_ids, cache_position) else: # custom_stream = torch.cuda.Stream() - torch.cuda.set_device(torch_device) + if torch.cuda.is_available(): + torch.cuda.set_device(torch_device) + elif torch.xpu.is_available(): + torch.xpu.set_device(torch_device) + else: + raise RuntimeError(f"The device: {torch_device} is not available") inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device) # with torch.cuda.stream(custom_stream): logits=model(inputs_embeds=inputs_embeds, @@ -156,10 +261,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0] - if past_key_values != None: + if past_key_values != None and isinstance(past_key_values, StaticCache): past_key_values.change_seq_length(1) - for device in all_cuda_device: - torch.cuda.synchronize(device) + sync_all_device(all_cuda_device) #print(logits) next_token_scores = logits_warper(inputs, logits[:, -1, :]) if generation_config.do_sample: @@ -185,11 +289,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud return logits - torch.cuda.set_device(torch_device) + if torch.cuda.is_available(): + torch.cuda.set_device(torch_device) + elif torch.xpu.is_available(): + torch.xpu.set_device(torch_device) + else: + raise RuntimeError(f"The device: {torch_device} is not available") with torch.no_grad(): stream = TextStreamer(tokenizer) - if mode != 'long_context': + if torch.xpu.is_available(): + from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache + if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]: + past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None) + else: + past_key_values = DynamicNormalCache.from_legacy_cache(None) + elif mode != 'long_context': past_key_values = StaticCache( config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype ) @@ -201,14 +316,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud # change this to modify generate config #top_k=5, top_p=0.85, temperature=0.1 ) - try: # transformers==4.43 - logits_warper = ( - model._get_logits_warper(generation_config,device=inputs.device) - ) - except: - logits_warper = ( - model._get_logits_warper(generation_config) - ) + + logits_warper = tf_logits_warper(generation_config) cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32) generated_ids = torch.zeros( diff --git a/ktransformers/util/weight_loader.py b/ktransformers/util/weight_loader.py new file mode 100644 index 0000000..9dda646 --- /dev/null +++ b/ktransformers/util/weight_loader.py @@ -0,0 +1,367 @@ +from abc import ABC, abstractmethod +import os +import torch +import numpy as np +from safetensors import safe_open +from typing import Dict, Any, Optional, Union + +class ModelLoader(ABC): + """ + Abstract base class for model loaders. + Defines the interface that all model loaders must implement. + """ + + @abstractmethod + def load_tensor(self, name: str, device: str = "cpu") -> torch.Tensor: + """ + Load a tensor by name. 
+ + Args: + name: Name of the tensor to load + device: Device to load the tensor to + + Returns: + The loaded tensor + """ + pass + + @classmethod + @abstractmethod + def supports_format(cls, path: str) -> bool: + """ + Check if this loader supports the given path format. + + Args: + path: Path to check + + Returns: + True if this loader supports the given path, False otherwise + """ + pass + + +class SafeTensorLoader(ModelLoader): + """ + Loader for SafeTensor format models. + """ + + def __init__(self, path: str): + """ + Initialize the SafeTensor loader. + + Args: + path: Path to the model directory or file + """ + self.tensor_file_map = {} # Maps tensor names to file paths + self.file_handle_map = {} # Maps file names to file handles + self._load_tensor_file_map(path) + + def _load_tensor_file_map(self, path: str) -> None: + """ + Load the tensor file map from the given path. + + Args: + path: Path to the model directory or file + """ + # Normalize path to directory + if not os.path.exists(path): + raise FileNotFoundError(f"Path not found: {path}") + if os.path.isfile(path): + folder_path = os.path.dirname(path) + else: + folder_path = path + + found_safetensor = False + for root, _, files in os.walk(folder_path): + files = sorted(files) + for file in files: + if file.endswith(".safetensors"): + found_safetensor = True + file_path = os.path.join(root, file) + if file not in self.file_handle_map: + try: + handle = safe_open(file_path, framework="pt") + self.file_handle_map[file] = handle + except Exception as e: + print(f"Error opening Safetensor file {file_path}: {e}") + continue + + f = self.file_handle_map.get(file) + if f is None: + continue + try: + for key in f.keys(): + self.tensor_file_map[key] = file + except Exception as e: + print(f"Error reading Safetensor file {file_path}: {e}") + + if not found_safetensor: + # Not raising an error here allows for the factory to try other loaders + print(f"No Safetensor files found in {folder_path}") + + def load_tensor(self, name: str, device: str = "cpu") -> torch.Tensor: + """ + Load a tensor by name. + + Args: + name: Name of the tensor to load + device: Device to load the tensor to + + Returns: + The loaded tensor + """ + if name not in self.tensor_file_map: + raise KeyError(f"Key {name} not found in Safetensor files") + file = self.tensor_file_map[name] + f = self.file_handle_map.get(file) + if f is None: + raise FileNotFoundError(f"File {file} not found in Safetensor files") + tensor = f.get_tensor(name) + return tensor.to(device) + + def load_dequantized_tensor(self, name: str, device: str = "cpu") -> torch.Tensor: + """ + Load and dequantize a tensor. + + Args: + name: Name of the tensor to load + device: Device to load the tensor to + + Returns: + The dequantized tensor + """ + if name not in self.tensor_file_map: + raise KeyError(f"Key {name} not found in Safetensor files") + file = self.tensor_file_map[name] + f = self.file_handle_map.get(file) + if f is None: + raise FileNotFoundError(f"File {file} not found in Safetensor files") + tensor = f.get_tensor(name).to(device) + if name.endswith(".weight"): + if name[:-7] + ".weight_scale_inv" in self.tensor_file_map: + weight_scale_inv = f.get_tensor(name[:-7] + ".weight_scale_inv").to(device) + # Assuming weight_dequant function is imported + from ktransformers.ktransformers_ext.triton.fp8gemm import weight_dequant + tensor = weight_dequant(tensor, weight_scale_inv) + return tensor.to(device) + + def close_all_handles(self) -> None: + """ + Close all file handles. 
+ """ + for handle in self.file_handle_map.values(): + handle.close() + self.file_handle_map.clear() + + @classmethod + def supports_format(cls, path: str) -> bool: + """ + Check if this loader supports the given path format. + + Args: + path: Path to check + + Returns: + True if safetensor files are found in the path, False otherwise + """ + # Normalize path to directory + if not os.path.exists(path): + return False + if os.path.isfile(path): + if path.endswith(".safetensors"): + return True + folder_path = os.path.dirname(path) + else: + folder_path = path + + # Check if any safetensor files exist in the folder + for root, _, files in os.walk(folder_path): + for file in files: + if file.endswith(".safetensors"): + return True + return False + + +class GGUFLoader(ModelLoader): + """ + Loader for GGUF format models. + """ + + def __init__(self, path: str): + """ + Initialize the GGUF loader. + + Args: + path: Path to the model directory or file + """ + # Check if path exists + if not os.path.exists(path): + raise FileNotFoundError(f"GGUF dir not found: {path}") + if os.path.isfile(path): + self.gguf_path = os.path.dirname(path) + else: + self.gguf_path = path + + self.tensor_info = {} # Stores tensor metadata + self.tensor_file_map = {} # Maps tensor names to file paths + self.file_data_map = {} # Maps file paths to memory-mapped data + self.gguf_file_meta = {} # Stores GGUF metadata + + # For compatibility with the factory pattern + self.safetensor_loader = None + + # Scan all GGUF files in the directory + found_gguf = False + for root, _, files in os.walk(self.gguf_path): + for file in files: + if file.endswith(".gguf"): + found_gguf = True + file_path = os.path.join(root, file) + with open(file_path, "rb") as f: + self._load_gguf(f) + if file_path not in self.file_data_map: + self.file_data_map[file_path] = np.memmap(file_path, mode='r') + + if not found_gguf: + raise FileNotFoundError(f"Cannot find any .gguf files in: {self.gguf_path}") + + def _load_gguf(self, f) -> None: + """ + Load GGUF file metadata and tensor info. + + Args: + f: File handle of the GGUF file + """ + # Implementation should follow the original GGUFLoader._load_gguf + # This is a simplified version for illustration + f.seek(0) + assert f.read(4) == b'GGUF' + + # Read header + values = struct.unpack(" Any: + """ + Read a value from the file according to its data type. + + Args: + f: File handle + data_type: Type of data to read + + Returns: + The read value + """ + # Simplified implementation + # In a complete implementation, this would handle all data types + if data_type == 8: # DATA_TYPES["string"] + length = struct.unpack(" torch.Tensor: + """ + Load a tensor by name. + + Args: + name: Name of the tensor to load + device: Device to load the tensor to + + Returns: + The loaded tensor + """ + # This should call load_gguf_tensor with the appropriate parameters + return self.load_gguf_tensor(name, device) + + def load_gguf_tensor(self, name: str, device: str = "cpu", target_dtype = None) -> torch.Tensor: + """ + Load a GGUF tensor by name. 
+ + Args: + name: Name of the tensor to load + device: Device to load the tensor to + target_dtype: Target data type for the tensor + + Returns: + The loaded tensor + """ + # Implementation would follow the original GGUFLoader.load_gguf_tensor + # This is a placeholder for illustration + if name not in self.tensor_info: + raise KeyError(f"Tensor {name} not found") + + # Actual implementation would dequantize the tensor data + # and return a torch.Tensor + return torch.zeros(1, device=device) # Placeholder + + @classmethod + def supports_format(cls, path: str) -> bool: + """ + Check if this loader supports the given path format. + + Args: + path: Path to check + + Returns: + True if GGUF files are found in the path, False otherwise + """ + # Normalize path to directory + if not os.path.exists(path): + return False + if os.path.isfile(path): + return path.endswith(".gguf") + + # Check if any GGUF files exist in the folder + for root, _, files in os.walk(path): + for file in files: + if file.endswith(".gguf"): + return True + return False \ No newline at end of file diff --git a/merge_tensors/merge_safetensor_gguf.py b/merge_tensors/merge_safetensor_gguf.py index efeab3b..f299ab9 100644 --- a/merge_tensors/merge_safetensor_gguf.py +++ b/merge_tensors/merge_safetensor_gguf.py @@ -6,7 +6,7 @@ import sys # sys.path.insert(0, "/home/azure/ktransformers") import argparse import torch -from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf +from ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf from safetensors import safe_open from safetensors.torch import save_file import re diff --git a/pyproject.toml b/pyproject.toml index fbf0924..9502c55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dependencies = [ "build", "fire", "protobuf", - "triton >= 3.2" ] requires-python = ">=3.10" @@ -70,7 +69,7 @@ ktransformers = "ktransformers.server.main:main" [tool.setuptools.packages.find] where = ["./", ] -include = ["ktransformers"] +include = ["ktransformers","ktransformers.*"] [tool.black] line-length = 120 preview = true diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt index dd3a206..25afaef 100644 --- a/requirements-local_chat.txt +++ b/requirements-local_chat.txt @@ -7,4 +7,3 @@ cpufeature; sys_platform == 'win32' or sys_platform == 'Windows' protobuf tiktoken blobfile -triton>=3.2 diff --git a/setup.py b/setup.py index 8ea6797..25e8a35 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,17 @@ try: from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME except ImportError: MUSA_HOME=None - +KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available() + +# Detect the DEV_BACKEND environment variable +dev_backend = os.environ.get("DEV_BACKEND", "").lower() +if dev_backend == "xpu": + triton_dep = [ + "pytorch-triton-xpu==3.3.0" + ] +else: + triton_dep = ["triton>=3.2"] + with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1" class CpuInstructInfo: @@ -241,8 +251,10 @@ class VersionInfo: backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}" elif ROCM_HOME is not None: backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}" + elif torch.xpu.is_available(): + backend_version = f"xpu" else: - raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.") + raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set and XPU is not available.") package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}" if full_version:
return package_version @@ -511,8 +523,10 @@ class CMakeBuild(BuildExtension): cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"] elif ROCM_HOME is not None: cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"] + elif KTRANSFORMERS_BUILD_XPU: + cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"] else: - raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.") + raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set and XPU is not available.") cmake_args = get_cmake_abi_args(cmake_args) # log cmake_args @@ -636,33 +650,41 @@ elif MUSA_HOME is not None: ] } ) +elif torch.xpu.is_available(): #XPUExtension is not available now. + ops_module = None else: - raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.") + raise ValueError("Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.") -ext_modules = [ - CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), - ops_module, - CUDAExtension( - 'vLLMMarlin', [ - 'csrc/custom_marlin/binding.cpp', - 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu', - 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu', - ], - extra_compile_args={ - 'cxx': ['-O3'], - 'nvcc': ['-O3', '-Xcompiler', '-fPIC'], - }, - ) -] -if with_balance: - print("using balance_serve") - ext_modules.append( - CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve")) - ) +if not torch.xpu.is_available(): + ext_modules = [ + CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), + ops_module, + CUDAExtension( + 'vLLMMarlin', [ + 'csrc/custom_marlin/binding.cpp', + 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu', + 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu', + ], + extra_compile_args={ + 'cxx': ['-O3'], + 'nvcc': ['-O3', '-Xcompiler', '-fPIC'], + }, + ) + ] + if with_balance: + print("using balance_serve") + ext_modules.append( + CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve")) + ) +else: + ext_modules = [ + CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), + ] setup( name=VersionInfo.PACKAGE_NAME, version=VersionInfo().get_package_version(), + install_requires=triton_dep, cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild}, ext_modules=ext_modules )
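
For orientation, here is a brief usage sketch (not part of the patch) of how the pieces introduced above fit together at runtime. The model path and tensor name are hypothetical; the behavior of ModelLoaderFactory.create_loader() and has_tensor() is assumed from ktransformers/util/custom_loader.py in this diff, and building with DEV_BACKEND=xpu is assumed to pull in pytorch-triton-xpu instead of triton>=3.2, as selected in setup.py.

# Hedged usage sketch (illustrative only; path and tensor name are hypothetical).
import torch
from ktransformers.util.custom_loader import ModelLoaderFactory

# create_loader() prefers SafeTensor files and falls back to GGUF, per the factory above.
loader = ModelLoaderFactory.create_loader("/path/to/model")  # hypothetical model directory

# has_tensor() also tries the GGUF-translated name, so HF-style keys resolve for both loaders.
name = "model.layers.0.mlp.gate.weight"  # hypothetical tensor name
if loader.has_tensor(name):
    print(f"{name} resolved by {type(loader).__name__}")

# Device selection mirrors utils.py: prefer CUDA, then XPU, then CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.xpu.is_available():
    device = "xpu:0"
else:
    device = "cpu"
print("running on", device)

This mirrors the factory's stated priority (SafeTensor over GGUF when both are present) and the CUDA/XPU fallbacks used throughout utils.py in this patch.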