diff --git a/Dockerfile.xpu b/Dockerfile.xpu
new file mode 100644
index 0000000..bb4d2dd
--- /dev/null
+++ b/Dockerfile.xpu
@@ -0,0 +1,68 @@
+# Base image
+FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04
+
+ARG http_proxy
+ARG https_proxy
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV CONDA_DIR=/opt/conda
+
+# Install dependencies
+RUN apt-get update && apt-get install -y \
+ wget \
+ curl \
+ bash \
+ git \
+ vim \
+ ca-certificates \
+ binutils \
+ cmake \
+ g++ \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install Miniforge
+RUN wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O /tmp/miniforge.sh && \
+ bash /tmp/miniforge.sh -b -p $CONDA_DIR && \
+ rm /tmp/miniforge.sh && \
+ $CONDA_DIR/bin/conda clean -afy
+
+# Add conda to PATH
+ENV PATH=$CONDA_DIR/bin:$PATH
+
+RUN bash -c "\
+ source /opt/conda/etc/profile.d/conda.sh && \
+ conda create --name ktransformers python=3.11 -y && \
+ conda activate ktransformers && \
+ conda env list && \
+ conda install -c conda-forge libstdcxx-ng -y && \
+ strings \$(find /opt/conda/envs/ktransformers/lib -name 'libstdc++.so.6') | grep GLIBCXX | grep 3.4.32 \
+"
+
+RUN bash -c "\
+ source /opt/conda/etc/profile.d/conda.sh && \
+ conda activate ktransformers && \
+ pip install ipex-llm[xpu_2.6]==2.3.0b20250518 --extra-index-url https://download.pytorch.org/whl/xpu && \
+ pip uninstall -y torch torchvision torchaudio && \
+ pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu && \
+ pip uninstall -y intel-opencl-rt dpcpp-cpp-rt && \
+ pip list \
+"
+
+# Clone and set up ktransformers repo
+RUN bash -c "\
+ source $CONDA_DIR/etc/profile.d/conda.sh && \
+ conda activate ktransformers && \
+ git clone https://github.com/kvcache-ai/ktransformers.git && \
+ cd ktransformers && \
+ git submodule update --init && \
+ sed -i 's/torch\.xpu\.is_available()/True/g' setup.py && \
+ bash install.sh --dev xpu \
+"
+
+# Init conda and prepare bashrc
+RUN conda init bash && \
+ echo "source $CONDA_DIR/etc/profile.d/conda.sh" >> ~/.bashrc && \
+ echo "conda activate ktransformers" >> ~/.bashrc
+
+WORKDIR /ktransformers/
+CMD ["bash"]
diff --git a/README.md b/README.md
index 6ad3a59..b5f62da 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,8 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
+* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).
+
* **Apr 29, 2025**: Support AMX-Int8、 AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))
https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2
@@ -116,6 +118,16 @@ https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12
Getting started with KTransformers is simple! Follow the steps below to set up and start using it.
+we have already supported vendors:
+
+- Metax
+- Sanechips (ZhuFeng V1.0)
+- Intel
+- Ascend
+- Kunpeng
+- AMD
+
+
### 📥 Installation
To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).
diff --git a/WeChatGroup.png b/WeChatGroup.png
index c7f3c2d..8a53460 100644
Binary files a/WeChatGroup.png and b/WeChatGroup.png differ
diff --git a/csrc/ktransformers_ext/CMakeLists.txt b/csrc/ktransformers_ext/CMakeLists.txt
index 217de78..0ed4ef4 100644
--- a/csrc/ktransformers_ext/CMakeLists.txt
+++ b/csrc/ktransformers_ext/CMakeLists.txt
@@ -41,6 +41,7 @@ option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" ON)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
+option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
@@ -303,6 +304,8 @@ elseif (UNIX)
message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif()
+ elseif (KTRANSFORMERS_USE_XPU)
+ add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
else()
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
@@ -361,6 +364,7 @@ elseif(UNIX)
message(STATUS "Building for HIP")
elseif(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
+ elseif(KTRANSFORMERS_USE_XPU)
else()
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()
diff --git a/csrc/ktransformers_ext/cpu_backend/cpuinfer.h b/csrc/ktransformers_ext/cpu_backend/cpuinfer.h
index 9c7e781..7b1d898 100644
--- a/csrc/ktransformers_ext/cpu_backend/cpuinfer.h
+++ b/csrc/ktransformers_ext/cpu_backend/cpuinfer.h
@@ -17,6 +17,7 @@
#include
#include
#include
+ #include
#ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
@@ -66,10 +67,14 @@
}
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair params) {
+ #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)
void (*func)(void*) = (void (*)(void*))params.first;
void* args = (void*)params.second;
*((CPUInfer**)args) = this;
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
+ #else
+ throw std::runtime_error("submit_with_cuda_stream is not supported on this platforma");
+ #endif
}
static void sync_(void* cpu_infer_ptr) {
@@ -78,7 +83,11 @@
}
void sync_with_cuda_stream(intptr_t user_cuda_stream) {
+ #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
+ #else
+ throw std::runtime_error("sync_with_cuda_stream is not supported on this platforma");
+ #endif
}
public:
diff --git a/csrc/ktransformers_ext/cuda/test_dequant.py b/csrc/ktransformers_ext/cuda/test_dequant.py
index abca745..c39d6c7 100644
--- a/csrc/ktransformers_ext/cuda/test_dequant.py
+++ b/csrc/ktransformers_ext/cuda/test_dequant.py
@@ -1,7 +1,7 @@
import os
import sys
sys.path.insert(0,"/home/zbx/ktransformers")
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader
import torch
gguf_loader_1 = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
diff --git a/csrc/ktransformers_ext/ext_bindings.cpp b/csrc/ktransformers_ext/ext_bindings.cpp
index 2767679..f0aeaa5 100644
--- a/csrc/ktransformers_ext/ext_bindings.cpp
+++ b/csrc/ktransformers_ext/ext_bindings.cpp
@@ -9,7 +9,7 @@
**/
// Python bindings
#include "cpu_backend/cpuinfer.h"
-#ifndef KTRANSFORMERS_USE_ROCM
+#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)
#include "device_launch_parameters.h"
#endif
#include "llamafile/flags.h"
diff --git a/csrc/ktransformers_ext/operators/amx/la/amx.hpp b/csrc/ktransformers_ext/operators/amx/la/amx.hpp
index 3338e09..866300d 100644
--- a/csrc/ktransformers_ext/operators/amx/la/amx.hpp
+++ b/csrc/ktransformers_ext/operators/amx/la/amx.hpp
@@ -843,7 +843,7 @@ inline void mat_mul(int m, int n, int k, std::shared_ptrget_submat(m, k, m_begin, k_block_begin + k_begin);
__m512bh *b512 = (__m512bh *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
- for (int m_i = 0; m_i < m; m_i++) {
+ for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512bh ma = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -914,7 +914,7 @@ inline void mat_mul(int m, int n, int k, std::shared_ptrget_submat(m, k, m_begin, k_block_begin + k_begin);
__m512i *b512 = (__m512i *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
- for (int m_i = 0; m_i < m; m_i++) {
+ for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
diff --git a/csrc/ktransformers_ext/operators/amx/moe.hpp b/csrc/ktransformers_ext/operators/amx/moe.hpp
index 7e966ae..81df642 100644
--- a/csrc/ktransformers_ext/operators/amx/moe.hpp
+++ b/csrc/ktransformers_ext/operators/amx/moe.hpp
@@ -272,8 +272,8 @@ public:
void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output,
int *batch_size_tensor, Backend *backend) {
- bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);
qlen = batch_size_tensor[0];
+ bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);
int activated_expert = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_num_[i] = 0;
@@ -395,4 +395,4 @@ public:
}
};
-#endif
\ No newline at end of file
+#endif
diff --git a/doc/README.md b/doc/README.md
index 05df2d3..199b990 100644
--- a/doc/README.md
+++ b/doc/README.md
@@ -22,6 +22,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
+* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./en/xpu.md)).
* **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./en/llama4.md)).
* **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./en/balance-serve.md)).
* **Mar 27, 2025**: Support Multi-concurrency.
diff --git a/doc/en/Docker_xpu.md b/doc/en/Docker_xpu.md
new file mode 100644
index 0000000..cb92d01
--- /dev/null
+++ b/doc/en/Docker_xpu.md
@@ -0,0 +1,94 @@
+# Intel GPU Docker Guide (Beta)
+
+## Prerequisites
+
+* Docker must be installed and running on your system.
+* Create a folder to store big models & intermediate files (e.g., /mnt/models)
+* **Before proceeding, ensure the Intel GPU driver is installed correctly on your host:** [Installation Guide](./xpu.md#1-install-intel-gpu-driver)
+
+---
+
+## Building the Docker Image Locally
+
+1. Clone the repository and navigate to the project directory:
+
+ ```bash
+ git clone https://github.com/kvcache-ai/ktransformers.git
+ cd ktransformers
+ ```
+
+2. Build the Docker image using the XPU-specific [Dockerfile.xpu](../../Dockerfile.xpu):
+
+ ```bash
+ sudo http_proxy=$HTTP_PROXY \
+ https_proxy=$HTTPS_PROXY \
+ docker build \
+ --build-arg http_proxy=$HTTP_PROXY \
+ --build-arg https_proxy=$HTTPS_PROXY \
+ -t kt_xpu:0.3.1 \
+ -f Dockerfile.xpu \
+ .
+ ```
+
+---
+
+## Running the Container
+
+### 1. Start the container
+
+```bash
+sudo docker run -td --privileged \
+ --net=host \
+ --device=/dev/dri \
+ --shm-size="16g" \
+ -v /path/to/models:/models \
+ -e http_proxy=$HTTP_PROXY \
+ -e https_proxy=$HTTPS_PROXY \
+ --name ktransformers_xpu \
+ kt_xpu:0.3.1
+```
+
+**Note**: Replace `/path/to/models` with your actual model directory path (e.g., `/mnt/models`).
+
+---
+
+### 2. Access the container
+
+```bash
+sudo docker exec -it ktransformers_xpu /bin/bash
+```
+
+---
+
+### 3. Set required XPU environment variables (inside the container)
+
+```bash
+export SYCL_CACHE_PERSISTENT=1
+export ONEAPI_DEVICE_SELECTOR=level_zero:0
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+```
+
+---
+
+### 4. Run the sample script
+
+```bash
+python ktransformers/local_chat.py \
+ --model_path deepseek-ai/DeepSeek-R1 \
+ --gguf_path \
+ --optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \
+ --cpu_infer \
+ --device xpu \
+ --max_new_tokens 200
+```
+
+**Note**:
+
+* Replace `` with the path to your GGUF model files.
+* Replace `` with the number of CPU cores you want to use plus one.
+
+---
+
+## Additional Information
+
+For more configuration options and usage details, refer to the [project README](../../README.md). To run KTransformers natively on XPU (outside of Docker), please refer to [xpu.md](./xpu.md).
diff --git a/doc/en/install.md b/doc/en/install.md
index 031b541..49cc0f9 100644
--- a/doc/en/install.md
+++ b/doc/en/install.md
@@ -45,7 +45,7 @@ Some preparation:
sudo apt-get update
sudo apt-get install build-essential cmake ninja-build patchelf
```
-- We recommend using [Miniconda3](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh) or [Anaconda3](https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh) to create a virtual environment with Python=3.11 to run our program. Assuming your Anaconda installation directory is `~/anaconda3`, you should ensure that the version identifier of the GNU C++standard library used by Anaconda includes `GLIBCXX-3.4.32`
+- We recommend using [Miniconda3](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh) or [Anaconda3](https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh) to create a virtual environment with Python=3.11 to run our program. Assuming your Anaconda installation directory is `~/anaconda3`, you should ensure that the version identifier of the GNU C++standard library used by Anaconda includes `GLIBCXX_3.4.32`
```sh
conda create --name ktransformers python=3.11
diff --git a/doc/en/xpu.md b/doc/en/xpu.md
new file mode 100644
index 0000000..78a1923
--- /dev/null
+++ b/doc/en/xpu.md
@@ -0,0 +1,134 @@
+# Intel GPU Support for KTransformers (Beta)
+
+## Introduction
+
+### Overview
+We are excited to introduce **Intel GPU support** in KTransformers (Beta release). This implementation has been tested and developed using Intel Xeon Scalable processors and Intel Arc GPUs (such as A770 and B580).
+
+## Installation Guide
+
+### 1. Install Intel GPU Driver
+Begin by installing the GPU drivers for your Intel GPU:
+- [Official GPU Installation Guide for Intel GPUs](https://dgpu-docs.intel.com/driver/overview.html)
+
+To verify that the kernel and compute drivers are installed and functional:
+
+```bash
+clinfo --list | grep Device
+ `-- Device #0: 13th Gen Intel(R) Core(TM) i9-13900K
+ `-- Device #0: Intel(R) Arc(TM) A770 Graphics
+ `-- Device #0: Intel(R) UHD Graphics 770
+```
+
+> [!Important]
+> Ensure that **Resizable BAR** is enabled in your system's BIOS before proceeding. This is essential for optimal GPU performance and to avoid potential issues such as `Bus error (core dumped)`. For detailed steps, please refer to the official guidance [here](https://www.intel.com/content/www/us/en/support/articles/000090831/graphics.html).
+
+### 2. Set Up Conda Environment
+We recommend using Miniconda3/Anaconda3 for environment management:
+
+```bash
+# Download Miniconda
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+
+# Create environment
+conda create --name ktransformers python=3.11
+conda activate ktransformers
+
+# Install required libraries
+conda install -c conda-forge libstdcxx-ng
+
+# Verify GLIBCXX version (should include 3.4.32)
+strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
+```
+
+> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`
+
+### 3. Install PyTorch and IPEX-LLM
+Install PyTorch with XPU backend support and [IPEX-LLM](https://github.com/intel/ipex-llm):
+
+```bash
+pip install ipex-llm[xpu_2.6]==2.3.0b20250518 --extra-index-url https://download.pytorch.org/whl/xpu
+pip uninstall torch torchvision torchaudio
+pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu # install torch2.7
+pip uninstall intel-opencl-rt dpcpp-cpp-rt
+```
+
+### 4. Build ktransformers
+
+```bash
+# Clone repository
+git clone https://github.com/kvcache-ai/ktransformers.git
+cd ktransformers
+git submodule update --init
+
+# Install dependencies
+bash install.sh --dev xpu
+```
+
+## Running DeepSeek-R1 Models
+
+### Configuration for 16B VRAM GPUs
+Use our optimized configuration for constrained VRAM:
+
+```bash
+export SYCL_CACHE_PERSISTENT=1
+export ONEAPI_DEVICE_SELECTOR=level_zero:0
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+
+python ktransformers/local_chat.py \
+ --model_path deepseek-ai/DeepSeek-R1 \
+ --gguf_path \
+ --optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \
+ --cpu_infer \
+ --device xpu \
+ --max_new_tokens 200
+```
+
+## Known Limitations
+- Serving function is not supported on Intel GPU platform for now
+
+## Troubleshooting
+1. Best Known Config (BKC) to obtain best performance
+
+To obtain best performance on Intel GPU platform, we recommend to lock GPU frequency and set CPU to performance mode by below settings.
+```bash
+echo "performance" | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
+echo 0 | sudo tee /sys/devices/system/cpu/cpu*/power/energy_perf_bias
+# 2400 is max frequency for Arc A770
+sudo xpu-smi config -d 0 -t 0 --frequencyrange 2400,2400
+# 2850 is max frequency for Arc B580
+# sudo xpu-smi config -d 0 -t 0 --frequencyrange 2850,2850
+```
+
+2. Runtime error like `xpu/sycl/TensorCompareKernels.cpp:163: xxx. Aborted (core dumped)`
+
+This error is mostly related to GPU driver. If you meet such error, you could update your `intel-level-zero-gpu` to `1.3.29735.27-914~22.04` (which is a verified version by us) by below command.
+```bash
+wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \
+sudo gpg --dearmor --output /usr/share/keyrings/intel-graphics.gpg
+echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu jammy client" | \
+sudo tee /etc/apt/sources.list.d/intel-gpu-jammy.list
+sudo apt update
+# or sudo apt update --allow-insecure-repositories
+sudo apt install intel-level-zero-gpu=1.3.29735.27-914~22.04
+```
+
+3. `ImportError: cannot import name 'intel' from 'triton._C.libtriton'`
+
+Installing Triton causes pytorch-triton-xpu to stop working. You can resolve the issue with following command:
+```bash
+pip uninstall triton pytorch-triton-xpu
+# Reinstall correct version of pytorch-triton-xpu
+pip install pytorch-triton-xpu==3.3.0 --index-url https://download.pytorch.org/whl/xpu
+```
+
+4. `ValueError: Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.`
+
+Ensure you have permissions to access /dev/dri/renderD*. This typically requires your user to be in the render group:
+```bash
+sudo gpasswd -a ${USER} render
+newgrp render
+```
+
+## Additional Information
+To run KTransformers on XPU with Docker, please refer to [Docker_xpu.md](./Docker_xpu.md).
diff --git a/install-with-cache.sh b/install-with-cache.sh
new file mode 100755
index 0000000..cef4341
--- /dev/null
+++ b/install-with-cache.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+set -e
+
+# clear build dirs
+# rm -rf build
+# rm -rf *.egg-info
+# rm -rf csrc/build
+# rm -rf csrc/ktransformers_ext/build
+# rm -rf csrc/ktransformers_ext/cuda/build
+# rm -rf csrc/ktransformers_ext/cuda/dist
+# rm -rf csrc/ktransformers_ext/cuda/*.egg-info
+rm -rf ~/.ktransformers
+echo "Installing python dependencies from requirements.txt"
+pip install -r requirements-local_chat.txt
+pip install -r ktransformers/server/requirements.txt
+echo "Installing ktransformers"
+KTRANSFORMERS_FORCE_BUILD=TRUE USE_BALANCE_SERVE=1 pip install -v . --no-build-isolation
+pip install third_party/custom_flashinfer/ -v
+
+# SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])")
+# echo "Copying thirdparty libs to $SITE_PACKAGES"
+# cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/
+# patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython*
+
+
+echo "Installation completed successfully"
diff --git a/install.sh b/install.sh
index 06866c6..573826e 100644
--- a/install.sh
+++ b/install.sh
@@ -1,6 +1,20 @@
#!/bin/bash
set -e
+# default backend
+DEV="cuda"
+
+# parse --dev argument
+while [[ "$#" -gt 0 ]]; do
+ case $1 in
+ --dev) DEV="$2"; shift ;;
+ *) echo "Unknown parameter passed: $1"; exit 1 ;;
+ esac
+ shift
+done
+export DEV_BACKEND="$DEV"
+echo "Selected backend: $DEV_BACKEND"
+
# clear build dirs
rm -rf build
rm -rf *.egg-info
@@ -13,14 +27,17 @@ rm -rf ~/.ktransformers
echo "Installing python dependencies from requirements.txt"
pip install -r requirements-local_chat.txt
pip install -r ktransformers/server/requirements.txt
+
echo "Installing ktransformers"
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation
-pip install third_party/custom_flashinfer/
+if [[ "$DEV_BACKEND" == "cuda" ]]; then
+ echo "Installing custom_flashinfer for CUDA backend"
+ pip install third_party/custom_flashinfer/
+fi
# SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])")
# echo "Copying thirdparty libs to $SITE_PACKAGES"
# cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/
# patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython*
-
echo "Installation completed successfully"
\ No newline at end of file
diff --git a/ktransformers/__init__.py b/ktransformers/__init__.py
index fa10c92..1fa9717 100644
--- a/ktransformers/__init__.py
+++ b/ktransformers/__init__.py
@@ -8,4 +8,4 @@ Version : 1.0.0
LastEditors : chenxl
LastEditTime : 2025-02-15 03:53:02
'''
-__version__ = "0.3"
+__version__ = "0.3.1"
diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml
index c4f6186..ed1713b 100644
--- a/ktransformers/configs/config.yaml
+++ b/ktransformers/configs/config.yaml
@@ -29,7 +29,7 @@ model:
gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF
device: cuda:0
- cache_lens: 8192
+ cache_lens: 16384
max_new_tokens: 500
web:
mount: False
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 928de48..75e12fb 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate, get_compute_capability
+from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
@@ -63,18 +63,24 @@ def local_chat(
prompt_file : str | None = None,
mode: str = "normal",
force_think: bool = False,
- chunk_size: int = 8192
+ chunk_size: int = 8192,
+ device: str = "cuda"
):
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
+ Config().chunk_size = chunk_size
+ if torch.xpu.is_available():
+ use_cuda_graph = False
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if mode == 'long_context':
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
+ elif xpu_fp16_model(config):
+ torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
@@ -89,11 +95,16 @@ def local_chat(
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
-
+ if torch.xpu.is_available():
+ config._attn_implementation = "eager"
model = custom_models[config.architectures[0]](config)
else:
+ if torch.xpu.is_available():
+ attn_implementation = "eager"
+ else:
+ attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_config(
- config, trust_remote_code=True, attn_implementation="flash_attention_2"
+ config, trust_remote_code=True, attn_implementation=attn_implementation
)
if optimize_config_path is None:
@@ -109,7 +120,7 @@ def local_chat(
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
- optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
+ optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, default_device=device)
try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
@@ -172,12 +183,12 @@ def local_chat(
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
generated = prefill_and_generate(
- model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
+ model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
)
else:
generated = prefill_and_generate(
- model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
+ model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
)
diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py
index e4a271e..350af73 100644
--- a/ktransformers/models/custom_cache.py
+++ b/ktransformers/models/custom_cache.py
@@ -66,7 +66,7 @@ class StaticCache(transformers.StaticCache):
self.page_table_list = []
for idx in range(config.num_hidden_layers):
if isinstance(device, dict):
- target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
+ target_device = device[f"model.layers.{idx}.self_attn"]["generate_device"]
else:
target_device = device
@@ -91,7 +91,7 @@ class StaticCache(transformers.StaticCache):
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
if isinstance(device, dict):
- target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
+ target_device = device[f"model.layers.{idx}.self_attn"]["generate_device"]
else:
target_device = device
@@ -213,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module):
self.v_caches = []
- def load(self, inference_context: "sched_ext.InferenceContext"):
+ def load(self, inference_context: "sched_ext.InferenceContext"):
for i in range(self.config.num_hidden_layers):
self.k_caches.append(
@@ -293,7 +293,7 @@ class KGQACache(nn.Module):
self.v_caches = []
- def load(self, inference_context: sched_ext.InferenceContext):
+ def load(self, inference_context: "sched_ext.InferenceContext"):
print(self.config.num_hidden_layers)
for i in range(self.config.num_hidden_layers):
self.k_caches.append(
diff --git a/ktransformers/models/custom_modeling_qwen2_moe.py b/ktransformers/models/custom_modeling_qwen2_moe.py
index 5740c14..1c84cbf 100644
--- a/ktransformers/models/custom_modeling_qwen2_moe.py
+++ b/ktransformers/models/custom_modeling_qwen2_moe.py
@@ -39,7 +39,7 @@ class KQwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
self.cache = cache
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
- self.attn = [None] * 10
+ self.attn = [None] * 100
def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):
self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)
diff --git a/ktransformers/models/custom_modeling_qwen3_moe.py b/ktransformers/models/custom_modeling_qwen3_moe.py
index 1cb8c46..32b9797 100644
--- a/ktransformers/models/custom_modeling_qwen3_moe.py
+++ b/ktransformers/models/custom_modeling_qwen3_moe.py
@@ -39,7 +39,7 @@ class KQwen3MoeForCausalLM(Qwen3MoePreTrainedModel):
self.cache = cache
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
- self.attn = [None] * 10
+ self.attn = [None] * 100
def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):
self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)
diff --git a/ktransformers/models/modeling_deepseek.py b/ktransformers/models/modeling_deepseek.py
index e14a521..f6845ec 100644
--- a/ktransformers/models/modeling_deepseek.py
+++ b/ktransformers/models/modeling_deepseek.py
@@ -107,6 +107,7 @@ class DeepseekV2RMSNorm(nn.Module):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
+ self.hidden_size = hidden_size
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
diff --git a/ktransformers/models/modeling_deepseek_v3.py b/ktransformers/models/modeling_deepseek_v3.py
index f296d9f..3a59d77 100644
--- a/ktransformers/models/modeling_deepseek_v3.py
+++ b/ktransformers/models/modeling_deepseek_v3.py
@@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
@@ -1598,7 +1599,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
return causal_mask
-class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
+class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py
index 75d1a6e..85d6556 100644
--- a/ktransformers/operators/RoPE.py
+++ b/ktransformers/operators/RoPE.py
@@ -23,7 +23,7 @@ from ktransformers.models.modeling_deepseek import (
yarn_find_correction_range
)
from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig
import torch
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 0f5f9ae..9dfdbdc 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -13,9 +13,10 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
+from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader
from ktransformers.util.utils import get_compute_capability
import logging
from transformers.configuration_utils import PretrainedConfig
@@ -587,6 +588,100 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
return attn_output, None, past_key_value
+ def forward_xpu(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.q_lora_rank is None:
+ q = self.q_proj(hidden_states)
+ else:
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+ query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ compressed_kv, k_pe = torch.split(
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+ )
+ k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+ kv = (
+ self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+ .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+ .transpose(1, 2)
+ )
+
+ k_nope, value_states = torch.split(
+ kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+ )
+ kv_seq_len = value_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ position_embeddings = kwargs.get("position_embeddings", None)
+ if position_embeddings is not None:
+ cos, sin = position_embeddings
+ key_states = torch.cat(
+ [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
+ dim=-1
+ )
+ from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+ rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim :],
+ key_states[:, :, :, self.qk_nope_head_dim:],
+ cos, sin, True)
+ else:
+ q_nope, q_pe = torch.split(
+ query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+ cos, sin = self.rotary_emb(q_pe, position_ids)
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
+ query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+ query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+ query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+
+ key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+ key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+ key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(
+ key_states.half(), value_states.half(), self.layer_idx, cache_kwargs
+ )
+
+ attn_weights = None
+ from ipex_llm.transformers.models.common import scaled_dot_product_attention
+ attn_output = scaled_dot_product_attention(
+ query_states.half(), key_states, value_states,
+ attention_mask.half(), q_len == kv_seq_len, self.softmax_scale
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+ attn_output = self.o_proj(attn_output).to(hidden_states.dtype)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -598,7 +693,21 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+ if torch.xpu.is_available():
+ return self.forward_xpu(
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_value,
+ output_attentions,
+ use_cache,
+ cache_position,
+ **kwargs,
+ )
+ elif (os.name == 'nt'
+ or get_compute_capability() < 8
+ or hidden_states.device.type == 'cpu'
+ or device_manager.gpu_vendor != GPUVendor.NVIDIA):
return self.forward_windows(
hidden_states,
attention_mask,
@@ -762,3 +871,75 @@ class KLlamaAttention(BaseInjectedModule):
attn_weights = None
return attn_output, attn_weights, past_key_value
+
+
+class KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "xpu",
+ generate_device: str = "xpu",
+ chunck_size: int = 1000,
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.config,
+ orig_module.layer_idx)
+ self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
+ assert prefill_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
+ assert generate_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids: Optional[torch.Tensor],
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ bsz, q_len, _ = hidden_states.size()
+ input_dtype = hidden_states.dtype
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ if not hasattr(self, 'qkv_proj'):
+ from ipex_llm.transformers.models.common import merge_quantized_qkv
+ merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module)
+
+ qkv = self.qkv_proj(hidden_states)
+ qkv = qkv.view(bsz, q_len, -1, self.head_dim)
+ qkv = qkv.transpose(1, 2)
+ query_states, key_states, value_states = qkv.split([self.config.num_attention_heads,
+ self.config.num_key_value_heads,
+ self.config.num_key_value_heads], dim=1)
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ if position_embeddings is None:
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ cos, sin = position_embeddings
+
+ from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+ rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states.half(), value_states.half(),
+ self.layer_idx, cache_kwargs)
+
+ attn_weights = None
+ from ipex_llm.transformers.models.common import scaled_dot_product_attention
+ attn_output = scaled_dot_product_attention(
+ query_states.half(), key_states, value_states,
+ attention_mask.half(), q_len == key_states.size(2), self.scaling
+ )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output).to(input_dtype)
+ return attn_output, attn_weights
diff --git a/ktransformers/operators/balance_serve_attention.py b/ktransformers/operators/balance_serve_attention.py
index a785413..51695f3 100644
--- a/ktransformers/operators/balance_serve_attention.py
+++ b/ktransformers/operators/balance_serve_attention.py
@@ -11,7 +11,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader
import logging
from transformers.configuration_utils import PretrainedConfig
from flashinfer import BatchMLAPagedAttentionWrapper
@@ -288,3 +288,170 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)
return attn_output
+
+
+class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "cuda",
+ generate_device: str = "cuda",
+ chunck_size: int = 1000,
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.config,
+ orig_module.layer_idx)
+ self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
+
+
+ def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
+ kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
+ q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
+ out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
+ self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
+ bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
+ self.q_absorb.weight.data = q_absorb
+ self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
+ bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
+ self.out_absorb.weight.data = out_absorb
+ #del self.orig_module.kv_b_proj
+ q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
+ out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
+ return q_absorb, out_absorb
+
+
+
+ def forward(self,
+ hidden_states: torch.Tensor,
+ kv_cache: KDeepSeekV3Cache,
+ position_ids: torch.Tensor,
+ wrapper: None,
+ num_tokens_tensors: torch.Tensor,
+ page_idx: torch.Tensor,
+ page_offset: torch.Tensor,
+ attention_masks: Optional[list[torch.Tensor]] = None,
+ q_indptr: Optional[torch.Tensor] = None,
+ kv_indices: Optional[torch.Tensor] = None,
+ kv_indptr: Optional[torch.Tensor] = None,
+ bsz_tensors: Optional[torch.Tensor] = None,
+ last_page_len: Optional[torch.Tensor] = None,
+ ):
+ # range bsz_tensors
+ final_attention_output = torch.tensor([], device=hidden_states.device)
+ for i in range(bsz_tensors[0]):
+ batch_num_tokens_tensors = q_indptr[i+1] - q_indptr[i]
+ batch_last_page_len = last_page_len[i]
+ # kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe
+ batch_page_idx = page_idx[q_indptr[i]:q_indptr[i+1]]
+ batch_page_offset = page_offset[q_indptr[i]:q_indptr[i+1]]
+ # kv_page_nums is the number of pages for the current batch
+ kv_page_nums = kv_indptr[i+1] - kv_indptr[i]
+ # kv_total_len is the total length of the kv cache for the current batch (kv_len for algorithm)
+ kv_total_len = kv_page_nums * kv_cache.page_size
+ if batch_last_page_len is not None:
+ kv_total_len = kv_total_len - (kv_cache.page_size - batch_last_page_len)
+ # print(f"kv_total_len's shape {kv_total_len.shape}")
+ # kv_index is the index of the kv cache pages for the current batch
+ kv_index = kv_indices[kv_indptr[i]:kv_indptr[i+1]]
+ # we can index [kv_index, page_offset_indices] to get the kv cache for the current batch
+ # from q_indptr[i] to q_indptr[i+1] is the range of the current batch
+ batch_hidden_states = hidden_states[q_indptr[i]:q_indptr[i+1]]
+ batch_position_ids = position_ids[q_indptr[i]:q_indptr[i+1]]
+ q_len, _ = batch_hidden_states.size()
+ # print("q_len -> ", q_len)
+
+ if self.q_lora_rank is None:
+ q = self.q_proj(batch_hidden_states, batch_num_tokens_tensors)
+ else:
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(batch_hidden_states, batch_num_tokens_tensors), batch_num_tokens_tensors), batch_num_tokens_tensors)
+ # for v3, bsz, q_len, num_heads(128), qk_head_dim(192=128(nope)+64(rope))
+ q = q.view(q_len, self.num_heads, self.q_head_dim)
+ # q_nope is [q_len, num_heads(128), qk_nope_head_dim(128)]
+ # q_pe is [q_len, num_heads(128), qk_rope_head_dim(64)]
+ q_nope, q_pe = torch.split(
+ q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+ # compressed_kv is [q_len, kv_lora_rank(512) + rope(64)]
+ compressed_kv = self.kv_a_proj_with_mqa(batch_hidden_states, batch_num_tokens_tensors)
+ # compressed_kv is [q_len, kv_lora_rank(512)], k_pe is [q_len, rope(64)]
+ compressed_kv, k_pe = torch.split(
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+ )
+ compressed_kv = compressed_kv.contiguous()
+ compressed_kv = self.kv_a_layernorm(compressed_kv, batch_num_tokens_tensors)
+ # k_pe is [q_len, 1, qk_rope_head_dim(64)]
+ k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
+ # compressed_kv is [q_len, 1, kv_lora_rank(512)]
+ compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)
+
+ cos, sin = self.rotary_emb(q_pe, batch_position_ids.unsqueeze(0))
+ # print(f"q_pe shape{q_pe.shape}, k_pe shape {k_pe.shape}")
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
+ q_pe = q_pe.squeeze(0)
+ # q_pe is [num_heads(128), q_len, qk_rope_head_dim(64)]
+ q_pe.transpose_(0, 1)
+ if kv_cache is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "page_idx": batch_page_idx, "page_offset": batch_page_offset} # Specific to RoPE models
+ compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, batch_page_idx, batch_page_offset, cache_kwargs)
+ compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
+ k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)
+ # q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)]
+ # out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim
+ q_absorb, out_absorb = self.get_absorbed()
+ # q_nope is [num_heads(128), q_len, qk_nope_head_dim(128)]
+ q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below
+ # q_nope is [num_heads(128), q_len, kv_lora_rank(512)]
+ q_nope = torch.matmul(q_nope, q_absorb) # batched MM
+
+ # # q_nope is [q_len, num_heads(128), kv_lora_rank(512)]
+ # q_nope = q_nope.transpose(0, 1)
+
+ # we need to index out the compressed_kv and k_pe for the current batch
+ batch_compressed_kv = None
+ batch_k_pe = None
+ for page_index in kv_index:
+ if kv_total_len > kv_cache.page_size:
+ tmp_compressed_kv = compressed_kv[page_index, 0:kv_cache.page_size, :]
+ tmp_k_pe = k_pe[page_index, 0:kv_cache.page_size, :]
+ if batch_compressed_kv is None or batch_k_pe is None:
+ batch_compressed_kv = tmp_compressed_kv
+ batch_k_pe = tmp_k_pe
+ else:
+ batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
+ batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
+ kv_total_len -= kv_cache.page_size
+ else:
+ tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :]
+ tmp_k_pe = k_pe[page_index, 0:kv_total_len, :]
+ if batch_compressed_kv is None or batch_k_pe is None:
+ batch_compressed_kv = tmp_compressed_kv
+ batch_k_pe = tmp_k_pe
+ else:
+ batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
+ batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
+ break
+ # batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]
+ # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]
+ attention_weights = (torch.matmul(q_pe,batch_k_pe.mT) + torch.matmul(q_nope, batch_compressed_kv.mT)) * self.softmax_scale
+ # attention_weights is [num_heads(128), q_len, k_len]
+
+ # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(q_len,-1,-1).transpose(0,1)
+
+ # attention_masks[i] is [q_len, k_len]
+
+ attention_weights = (attention_weights + attention_masks[i])
+ # attention_weights shape is [num_heads(128), q_len, k_len]
+ attention_weights = nn.functional.softmax(attention_weights,dim=-1,dtype=torch.float32).to(q_pe.dtype)
+ attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),q_len, lora_rank(512)]
+ # out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)]
+ out_absorb = out_absorb.transpose(1,2)
+ # q for q_len, n for num_heads, h for v_head_dim, v for kv_lora_rank
+ attn_output = torch.matmul(attn_output, out_absorb) # [num_heads(128), q_len, v_head_dim(128)]
+ attn_output = attn_output.transpose(0, 1) # [q_len, num_heads(128), v_head_dim(128)]
+ attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
+ attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)
+ final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)
+ return final_attention_output
\ No newline at end of file
diff --git a/ktransformers/operators/base_operator.py b/ktransformers/operators/base_operator.py
index 0fa2efd..5e49709 100644
--- a/ktransformers/operators/base_operator.py
+++ b/ktransformers/operators/base_operator.py
@@ -6,7 +6,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
from typing import Any
from torch import nn, Tensor
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
import ktransformers.util.utils as utils
class BaseInjectedModule(nn.Module):
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 34f0af0..7a40168 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -26,7 +26,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
-from ktransformers.util.custom_gguf import GGMLQuantizationType, GGUFLoader
+from ktransformers.util.custom_gguf import GGMLQuantizationType
+from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader, ModelLoader
from ktransformers.util.utils import InferenceState
from ktransformers.server.config.config import Config
from transformers.activations import ACT2FN
@@ -39,8 +40,21 @@ from ktransformers.operators.cpuinfer import CPUInfer
def deduplicate_and_sort(lst):
return sorted(set(lst))
+def generate_cuda_graphs(chunk_size: int) -> list:
+ assert chunk_size <= 1024 or chunk_size % 1024 == 0, "chunk_size must <= 1024 or a multiple of 1024"
+ base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size]
+
+ if chunk_size <= 1024:
+ return deduplicate_and_sort(base_list)
+
+ multiples = [i for i in range(1024, chunk_size + 1, 1024)]
+
+ return deduplicate_and_sort(base_list + multiples)
#cuda_graphs = [Config().chunk_size]
-cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
+if torch.cuda.is_available():
+ cuda_graphs = generate_cuda_graphs(Config().chunk_size)
+else:
+ cuda_graphs = 1
# class Base(BaseInjectedModule, ABC):
class KExpertsBase(ABC):
def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
@@ -77,7 +91,7 @@ class KExpertsBase(ABC):
down_type = None
for key in keys:
- if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+ if self.gguf_loader.has_tensor(key + ".ffn_gate_exps.weight"):
targets = [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight" ]
tensors = self.load_multi(key, targets, device=device)
gate = tensors[".ffn_gate_exps.weight"]
@@ -86,7 +100,7 @@ class KExpertsBase(ABC):
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
- elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info:
+ elif self.gguf_loader.has_tensor(key + ".ffn_down.0.weight"):
# for supporting Mixtral-8x7B-Instuct
gate = []
up = []
@@ -166,6 +180,11 @@ class KExpertsCPU(KExpertsBase):
n_routed_experts = self.n_routed_experts
self.cpu_infer = KExpertsCPU.CPU_INFER
# n_routed_experts = len(self.orig_module)
+ model_dtype = torch.get_default_dtype()
+ if torch.xpu.is_available() and model_dtype == torch.float16:
+ hidden_type = 1 # fp16
+ else:
+ hidden_type = 30 # bf16
if self.backend == "llamafile":
moe_config = MOEConfig(
n_routed_experts,
@@ -181,7 +200,7 @@ class KExpertsCPU(KExpertsBase):
self.gate_type,
self.up_type,
self.down_type,
- 30, # TODO: get from model.dtype
+ hidden_type, # TODO: get from model.dtype
)
self.moe = MOE(moe_config)
elif self.backend == "AMXBF16":
@@ -194,7 +213,7 @@ class KExpertsCPU(KExpertsBase):
self.config.num_experts_per_tok,
self.config.hidden_size,
self.config.moe_intermediate_size,
- 25600,
+ max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
gate_ptr,
up_ptr,
down_ptr,
@@ -212,7 +231,7 @@ class KExpertsCPU(KExpertsBase):
self.config.num_experts_per_tok,
self.config.hidden_size,
self.config.moe_intermediate_size,
- 25600,
+ max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
gate_ptr,
up_ptr,
down_ptr,
@@ -241,8 +260,12 @@ class KExpertsCPU(KExpertsBase):
KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
- KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
- KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
+ if torch.xpu.is_available():
+ KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=model_dtype)
+ KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True)
+ else:
+ KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
+ KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
if bsz_tensor is None:
@@ -274,9 +297,9 @@ class KExpertsCPU(KExpertsBase):
def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
# generate, capture and run cuda graph
# print(expert_ids)
- if bsz_tensor is None:
+ if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1):
bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
- if torch.cuda.is_current_stream_capturing():
+ if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
if cuda_graph_idx != -1:
KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
@@ -296,6 +319,15 @@ class KExpertsCPU(KExpertsBase):
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device]
+ elif input_tensor.size(0)==1 and torch.xpu.is_available():
+ KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True)
+ KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True)
+ KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True)
+ # KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True)
+ self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
+ self.cpu_infer.sync()
+ KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
+ return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1)
else:
input_tensor = input_tensor.contiguous().cpu()
expert_ids = expert_ids.contiguous().cpu()
@@ -325,14 +357,19 @@ class KExpertsCPU(KExpertsBase):
down_type = None
for key in keys:
- if self.gguf_loader.safetensor_loader is not None:
- # using a temp ugly way to temprary load the tensor
- gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
- up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
- down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
- gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
- up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
- down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()
+ if isinstance(self.gguf_loader, SafeTensorLoader):
+ res = self.gguf_loader.load_experts(key)
+ return {key: res}
+ elif self.gguf_loader.has_tensor(key + ".ffn_gate_exps.weight"):
+ gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+ up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+ down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+ # gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
+ # up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
+ # down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
+ gate_type = self.gguf_loader.get_ggml_type(key + ".ffn_gate_exps.weight")
+ up_type = self.gguf_loader.get_ggml_type(key + ".ffn_up_exps.weight")
+ down_type = self.gguf_loader.get_ggml_type(key + ".ffn_down_exps.weight")
elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
@@ -356,9 +393,9 @@ class KExpertsCPU(KExpertsBase):
gate = np.stack(gate)
up = np.stack(up)
down = np.stack(down)
- gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate.0.weight"]["ggml_type"]
- up_type = self.gguf_loader.tensor_info[key + ".ffn_up.0.weight"]["ggml_type"]
- down_type = self.gguf_loader.tensor_info[key + ".ffn_down.0.weight"]["ggml_type"]
+ gate_type = self.gguf_loader.get_ggml_type(key + ".ffn_gate.0.weight")
+ up_type = self.gguf_loader.get_ggml_type(key + ".ffn_up.0.weight")
+ down_type = self.gguf_loader.get_ggml_type(key + ".ffn_down.0.weight")
else:
raise ValueError(f"Experts {key} not found in gguf_loader")
res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
@@ -445,7 +482,7 @@ class KExpertsMarlin(KExpertsBase):
down = None
for key in keys:
- if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+ if self.gguf_loader.has_tensor(key + ".ffn_gate_exps.weight"):
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
@@ -806,7 +843,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
- if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+ if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
@@ -906,7 +943,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
# only for generate phase
- if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+ if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
@@ -1106,7 +1143,7 @@ class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):
# only for generate phase
- if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+ if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
@@ -1288,7 +1325,7 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
routing_weights = routing_weights.to(hidden_states.dtype)
# only for generate phase
- if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+ if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
@@ -1384,24 +1421,33 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
return final_out
class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
- def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
+ def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
- router_logits = self.gate(hidden_states, bsz_tensor)
+ if bsz_tensor is None:
+ router_logits = self.gate(hidden_states)
+ else:
+ router_logits = self.gate(hidden_states, bsz_tensor)
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
- if self.norm_topk_prob:
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ if router_logits.device.type == "xpu":
+ from ipex_llm.transformers.models.common import moe_softmax_topk
+ selected_experts, routing_weights = moe_softmax_topk(
+ router_logits.half(), self.top_k, self.norm_topk_prob
+ )
+ else:
+ routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ if self.norm_topk_prob:
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
# only for generate phase
- if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+ if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
diff --git a/ktransformers/operators/flashinfer_batch_prefill_wrapper.py b/ktransformers/operators/flashinfer_batch_prefill_wrapper.py
index e934654..287affb 100644
--- a/ktransformers/operators/flashinfer_batch_prefill_wrapper.py
+++ b/ktransformers/operators/flashinfer_batch_prefill_wrapper.py
@@ -40,7 +40,7 @@ class flashInferAttn():
self.kv_layout = kv_layout
self.use_cuda_graph = use_cuda_graph
if flashInferAttn.float_workspace_buffer is None:
- flashInferAttn.float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device=device)
+ flashInferAttn.float_workspace_buffer = torch.empty(max_batch_token * 1024 * 1024, dtype=torch.uint8, device=device)
self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py
index cf5799e..f5f96c1 100644
--- a/ktransformers/operators/gate.py
+++ b/ktransformers/operators/gate.py
@@ -6,7 +6,7 @@ import os
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.operators.linear import KTransformersLinear
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader, ModelLoader, SafeTensorLoader
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
@@ -55,24 +55,20 @@ class KMoEGateBase(ABC):
down_type = None
for key in keys:
- key = ".".join(key.split(".")[:-1])
- if self.gguf_loader.safetensor_loader is not None:
- targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
- weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
- e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
- weight_type = weight.dtype
- e_score_correction_bias_type = e_score_correction_bias.dtype
- res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
- elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
- targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
+ # key = ".".join(key.split(".")[:-1])
+ if isinstance(self.gguf_loader, SafeTensorLoader):
+ res = self.gguf_loader.load_gate(key, device=device)
+ elif self.gguf_loader.has_tensor(key+".weight"):
+ # targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
+ targets = [".weight", ".e_score_correction_bias"]
tensors = self.load_multi(key, targets, device=device)
- weight = tensors[".ffn_gate_inp.weight"]
- e_score_correction_bias = tensors[".exp_probs_b.bias"]
- weight_type = self.gguf_loader.tensor_info[key + ".ffn_gate_inp.weight"]["ggml_type"]
- e_score_correction_bias_type = self.gguf_loader.tensor_info[key + ".exp_probs_b.bias"]["ggml_type"]
+ weight = tensors[".weight"]
+ e_score_correction_bias = tensors[".e_score_correction_bias"]
+ # weight_type = self.gguf_loader.tensor_info[key + ".weight"]["ggml_type"]
+ res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias}
else:
raise ValueError(f"Experts {key} not found in gguf_loader")
- res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
+
return res
def load_multi(self, key: str, keys: list[str], device: str = "cpu"):
@@ -106,8 +102,6 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
if w is None: w = self.load_weights(device=device)
if isinstance(w, dict):
- self.weight_type = w["weight_type"]
- self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
self.orig_module.weight = nn.Parameter(w["weight"])
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
@@ -175,8 +169,6 @@ class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
if w is None: w = self.load_weights(device=device)
if isinstance(w, dict):
- self.weight_type = w["weight_type"]
- self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
self.orig_module.weight = nn.Parameter(w["weight"])
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
@@ -190,4 +182,34 @@ class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
if self.weight is not None:
self.weight = None
if self.e_score_correction_bias is not None:
- self.e_score_correction_bias = None
\ No newline at end of file
+ self.e_score_correction_bias = None
+
+
+class KMoEGateIPEXLLM(KMoEGate):
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module = None,
+ generate_device: str = "xpu",
+ prefill_device: str = "xpu",
+ **kwargs,
+ ):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
+ KMoEGate.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+ self.generate_device = generate_device
+ self.prefill_device = prefill_device
+
+ def forward(self, hidden_states) -> torch.Tensor:
+ x = hidden_states.view(-1, hidden_states.size(-1))
+ logits = torch.nn.functional.linear(
+ x.type(torch.float32), self.orig_module.weight.type(torch.float32), None
+ )
+ scores = logits.sigmoid()
+
+ from ipex_llm.transformers.models.common import moe_group_topk
+ topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias,
+ self.n_group, self.topk_group, self.top_k,
+ self.norm_topk_prob, self.routed_scaling_factor)
+ return topk_idx, topk_weight.to(x.dtype)
\ No newline at end of file
diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py
index 62c5cba..796592c 100644
--- a/ktransformers/operators/layernorm.py
+++ b/ktransformers/operators/layernorm.py
@@ -29,11 +29,12 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.custom_gguf import GGUFLoader
-from flashinfer.norm import (
- fused_add_rmsnorm,
- rmsnorm,
-)
+from ktransformers.util.custom_loader import GGUFLoader
+if not torch.xpu.is_available():
+ from flashinfer.norm import (
+ fused_add_rmsnorm,
+ rmsnorm,
+ )
logger = logging.getLogger(__name__)
@@ -163,3 +164,62 @@ class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
+
+class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "cuda",
+ generate_device: str = "cuda",
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.hidden_size,
+ orig_module.variance_epsilon)
+
+ def forward(
+ self,
+ x,
+ batch_size_tensor: torch.Tensor = None,
+ residual: Optional[torch.Tensor] = None,
+ )-> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ if residual is not None:
+ x = x + residual
+ residual = x
+ # range batch_size_tensor for x
+ input_dtype = x.dtype
+ x = x.to(torch.float32)
+ variance = x.pow(2).mean(-1, keepdim=True)
+ x = x * torch.rsqrt(variance + self.variance_epsilon)
+ if residual is not None:
+ return self.weight * x.to(input_dtype), residual
+ return self.weight * x.to(input_dtype)
+
+
+class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "xpu",
+ generate_device: str = "xpu",
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.weight.shape[0],
+ orig_module.variance_epsilon)
+ self.eps = orig_module.variance_epsilon
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ from ipex_llm.transformers.models.common import rms_norm_forward
+ if x.dtype not in [torch.float32, torch.float16]:
+ output = rms_norm_forward(self, x.float())
+ else:
+ output = rms_norm_forward(self, x)
+ return output.to(x.dtype)
+
+ def load(self):
+ BaseInjectedModule.load(self)
+ if self.weight.dtype not in [torch.float32, torch.float16]:
+ self.weight = self.weight.float()
\ No newline at end of file
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index 293826e..654c9f9 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -14,18 +14,20 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import ctypes
import torch
from torch import Tensor, nn
-import KTransformersOps
-import vLLMMarlin
-from ktransformers.util.custom_gguf import GGUFLoader
+if not torch.xpu.is_available():
+ import KTransformersOps
+ import vLLMMarlin
+from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
from ktransformers.util.utils import InferenceState
-from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
- MarlinWorkspace,
- marlin_quantize,
- GPTQ_MARLIN_MIN_THREAD_N,
- GPTQ_MARLIN_MIN_THREAD_K,
- GPTQ_MARLIN_MAX_PARALLEL,
- vllm_marlin_quantize
-)
+if not torch.xpu.is_available():
+ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
+ MarlinWorkspace,
+ marlin_quantize,
+ GPTQ_MARLIN_MIN_THREAD_N,
+ GPTQ_MARLIN_MIN_THREAD_K,
+ GPTQ_MARLIN_MAX_PARALLEL,
+ vllm_marlin_quantize
+ )
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
@@ -83,15 +85,15 @@ class KLinearBase(ABC):
keys = [self.key]
for key in keys:
- if self.gguf_loader.safetensor_loader is not None:
+ if isinstance(self.gguf_loader, SafeTensorLoader):
# using safetensor_loader
- tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
- if key+'.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map:
- weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
+ tensor = self.gguf_loader.load_tensor(key+'.weight')
+ if self.gguf_loader.has_tensor(key+'.weight_scale_inv'):
+ weight_scale_inv = self.gguf_loader.load_tensor(key+'.weight_scale_inv')
return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
return nn.Parameter(tensor)
- elif key + ".weight" in self.gguf_loader.tensor_file_map:
+ elif self.gguf_loader.has_tensor(key + ".weight") or "kv_b_proj" in key:
if key + ".bias" in self.gguf_loader.tensor_file_map:
tensors = self.load_multi(key, ["weight", "bias"], device=device)
tensor = tensors["weight"]
@@ -99,6 +101,19 @@ class KLinearBase(ABC):
# self.qtype = GGML_TYPE_QTYPE_MAP[tensorinfo[key + ".weight"]["ggml_type"]]
# print(torch.isinf(tensor).any(), torch.isinf(bias).any())
return nn.Parameter(tensor), nn.Parameter(bias)
+ elif "kv_b_proj" in key and not self.gguf_loader.has_tensor(key + ".weight"):
+ attn_k_b_tensors = self.load_multi(key.replace("self_attn.kv_b_proj", "attn_k_b"), ["weight"], device=device)
+ attn_k_b = attn_k_b_tensors["weight"]
+ del attn_k_b_tensors
+ attn_k_b = attn_k_b.transpose(1, 2).contiguous()
+ attn_v_b_tensors = self.load_multi(key.replace("self_attn.kv_b_proj", "attn_v_b"), ["weight"], device=device)
+ attn_v_b = attn_v_b_tensors["weight"]
+ del attn_v_b_tensors
+ kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)
+ kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous()
+ del attn_k_b
+ del attn_v_b
+ return nn.Parameter(kv_b_proj)
else:
tensors = self.load_multi(key, ["weight"], device=device)
tensor = tensors["weight"]
@@ -502,6 +517,9 @@ class VLinearMarlin(KLinearBase):
marlin_s = self.marlin_s.to(x.dtype)
sms = -1
+ # padding x.shape[0] to avoid CUDA illegal memory access error
+ x, orig_size_m = self._pad_input(x)
+
x = vLLMMarlin.gptq_marlin_gemm(
x,
self.marlin_q_w,
@@ -511,26 +529,15 @@ class VLinearMarlin(KLinearBase):
self.workspace.scratch,
self.num_bits,
bsz_tensor,
- # torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
x.shape[0],
self.n,
x.shape[-1],
sms,
self.is_k_full,
)
- # x = KTransformersOps.gptq_marlin_gemm(
- # x,
- # self.marlin_q_w,
- # marlin_s,
- # self.g_idx,
- # self.sort_indices,
- # self.workspace.scratch,
- # self.num_bits,
- # x.shape[0],
- # self.n,
- # x.shape[-1],
- # self.is_k_full,
- # )
+
+ x = x[:orig_size_m]
+
if self.has_bias:
x = x + self.bias
orig_shape[-1] = self.n
@@ -546,6 +553,27 @@ class VLinearMarlin(KLinearBase):
self.sort_indices = None
self.workspace = None
+ def _pad_input(self, x):
+
+ size_m = x.shape[0]
+ size_k = x.shape[1]
+
+ # size_m and align value depends on VLinearMarlin implementation
+ if size_m > 1024:
+ align = 1024
+ elif size_m > 64:
+ align = 64
+ else:
+ align = 1
+
+ padded_size_m = ((size_m + align - 1) // align) * align
+
+ if padded_size_m > size_m:
+ pad_len = padded_size_m - size_m
+ pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device)
+ x = torch.cat([x, pad_tensor], dim = 0).contiguous()
+ return x, size_m
+
class KLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
@@ -760,7 +788,7 @@ class KLinearCPUInfer(KLinearBase):
self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device)
def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu"):
- if self.key + ".weight" in self.gguf_loader.tensor_info:
+ if self.gguf_loader.has_tensor(self.key + ".weight"):
if self.key + ".bias" in self.gguf_loader.tensor_file_map:
self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
@@ -778,6 +806,75 @@ class KLinearCPUInfer(KLinearBase):
if self.has_bias:
self.bias = None
+class KLinearIPEXLLM(KLinearBase):
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module = None,
+ device: str = "xpu",
+ precision: str = "sym_int4",
+ **kwargs,
+ ):
+ super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+ self.has_bias = False
+ self.dtype = torch.get_default_dtype()
+ self.weight = None
+ self.has_bias = False
+ self.precision = precision
+ self.qtype = None
+
+ def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
+ dtype = x.dtype
+ out_device = x.device
+ from ipex_llm.transformers.models.common import linear_forward
+ x = linear_forward(x.half(), self.weight, self.qtype, self.out_features)
+
+ if self.has_bias:
+ x = x + self.bias
+ x = x.to(dtype=dtype, device=out_device)
+ return x
+
+ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+ if self.loaded: return
+ if device is None: device = self.device
+ assert device.lower()[:3] == "xpu", "IPEX-LLM quantized linear only supports XPU device"
+ if w is None: w = self.load_weight(device=device)
+
+ if isinstance(w, nn.Parameter):
+ try:
+ weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+ except:
+ weight = w.to(dtype=self.dtype).T
+ self.has_bias = False
+ elif isinstance(w, tuple):
+ try:
+ weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+ except:
+ weight = w[0].to(dtype=self.dtype).T
+ self.bias = w[1].to(dtype=self.dtype)
+ self.has_bias = True
+ else:
+ raise ValueError("Invalid weight type")
+ weight = weight.to("cpu").float().transpose(0, 1).contiguous()
+
+ if self.has_bias:
+ self.bias = self.bias.to(device)
+
+ # quantize linear weight
+ from ipex_llm.transformers.models.common import quantize_linear
+ paramsLowBit, qtype = quantize_linear(weight, self.in_features, self.precision)
+ self.weight = paramsLowBit.to(device)
+ self.qtype = qtype
+ self.loaded = True
+
+ def unload(self):
+ if self.weight is not None:
+ self.weight = None
+ if self.has_bias:
+ self.bias = None
+
LINEAR_MAP = {
"KLinearMarlin": KLinearMarlin,
"KLinearTorch": KLinearTorch,
@@ -785,6 +882,7 @@ LINEAR_MAP = {
"VLinearMarlin": VLinearMarlin,
"KLinearFP8": KLinearFP8,
"KLinearQ8": KLinearQ8,
+ "KLinearIPEXLLM": KLinearIPEXLLM,
}
class KTransformersLinear(BaseInjectedModule, KLinearBase):
diff --git a/ktransformers/operators/mlp.py b/ktransformers/operators/mlp.py
index 02648b1..77d7d05 100644
--- a/ktransformers/operators/mlp.py
+++ b/ktransformers/operators/mlp.py
@@ -1,6 +1,6 @@
from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader
from transformers import PretrainedConfig
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py
index bbac29a..e136b57 100644
--- a/ktransformers/operators/models.py
+++ b/ktransformers/operators/models.py
@@ -58,7 +58,7 @@ from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState, get_compute_capability
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_llama import (
LlamaDecoderLayer,
@@ -306,6 +306,12 @@ class KQwen2MoeModel(BaseInjectedModule):
hidden_states = inputs_embeds
+ # create position embeddings to be shared across the decoder layers
+ if torch.xpu.is_available() and inputs_embeds.device.type == "xpu":
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ else:
+ position_embeddings = None
+
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
@@ -369,6 +375,7 @@ class KQwen2MoeModel(BaseInjectedModule):
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
+ position_embeddings=position_embeddings,
)
if per_layer_prefill_flag:
# print(f"to cpu")
@@ -376,8 +383,10 @@ class KQwen2MoeModel(BaseInjectedModule):
torch.cuda.empty_cache()
hidden_states = layer_outputs[0]
- if use_cache:
+ if use_cache and len(layer_outputs) > 1:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+ else:
+ next_decoder_cache = None
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -396,11 +405,14 @@ class KQwen2MoeModel(BaseInjectedModule):
next_cache = None
if use_cache:
- next_cache = (
- next_decoder_cache.to_legacy_cache()
- if use_legacy_cache
- else next_decoder_cache
- )
+ if next_decoder_cache is not None:
+ next_cache = (
+ next_decoder_cache.to_legacy_cache()
+ if use_legacy_cache
+ else next_decoder_cache
+ )
+ else:
+ next_cache = past_key_values
if not return_dict:
return tuple(
@@ -647,10 +659,20 @@ class KDeepseekV2Model(BaseInjectedModule):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
+ if inputs_embeds.device.type == "xpu" and position_ids is not None:
+ cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,
+ position_ids)
+ position_embeddings = (cos, sin)
+ else:
+ position_embeddings = None
+
if per_layer_prefill_flag:
causal_mask = None
else:
- if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+ if (os.name == 'nt'
+ or get_compute_capability() < 8
+ or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())
+ or device_manager.gpu_vendor != GPUVendor.NVIDIA):
# print("for Windows or GPU before ampere, use forward_windows")
# only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask(
@@ -734,6 +756,7 @@ class KDeepseekV2Model(BaseInjectedModule):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
+ position_embeddings=position_embeddings,
)
t5 = time.time()
if per_layer_prefill_flag:
diff --git a/ktransformers/optimize/optimize.py b/ktransformers/optimize/optimize.py
index 331e6cf..bbe08c8 100644
--- a/ktransformers/optimize/optimize.py
+++ b/ktransformers/optimize/optimize.py
@@ -12,7 +12,7 @@ from torch import nn
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
# from operators import BaseInjectedModule
-from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
+from ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory
from ktransformers.util.utils import set_module, load_weights
import itertools
import copy
@@ -54,7 +54,7 @@ def del_meta(module:nn.Module):
def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str="", default_device: str = "cuda:0"):
module_name = prefix[:-1]
- translated_name = translate_name_to_gguf(prefix)[:-1]
+ # translated_name = translate_name_to_gguf(prefix)[:-1]
#print("gen_optimize_config", prefix, module_name, translated_name)
recursive = True
for rule in rule_list:
@@ -76,7 +76,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
if "replace" in rule:
replace_meta = rule["replace"]
if module_name not in out_data:
- out_data[module_name]={"key": translated_name,
+ out_data[module_name]={"key": module_name,
"class": replace_meta["class"] if "class" in replace_meta else "default",
# "device": replace_meta["device"] if "device" in replace_meta else default_device,
"kwargs": copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()}
@@ -91,7 +91,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
if module_name not in out_data:
out_data[module_name]= {
"class": "default",
- "key": translated_name,
+ "key": module_name,
"kwargs": {"generate_device": default_device,
"prefill_device": default_device}
}
@@ -103,7 +103,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
for name, child in module._modules.items():
if child is not None:
child_prefix = prefix + name + "."
- gen_optimize_config(child, out_data, rule_list, child_prefix)
+ gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device)
def translate_model_config(model_config: PretrainedConfig):
@@ -123,12 +123,15 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
model_config = translate_model_config(model_config)
- gguf_loader=GGUFLoader(gguf_path)
+ weights_loader = ModelLoaderFactory.create_loader(gguf_path)
with torch.device("meta"):
- inject(module, optimize_config, model_config, gguf_loader)
+ inject(module, optimize_config, model_config, weights_loader)
# pre load lm_head because its big inter result
- load_weights(module.lm_head, gguf_loader, "lm_head.")
- load_weights(module, gguf_loader)
- module.gguf_loader = gguf_loader
+ load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
+ load_weights(module, weights_loader, device=default_device)
+ module.gguf_loader = weights_loader
del_meta(module)
- torch.cuda.empty_cache()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ elif torch.xpu.is_available():
+ torch.xpu.empty_cache()
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-gpu-cpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-gpu-cpu.yaml
new file mode 100644
index 0000000..3425add
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-gpu-cpu.yaml
@@ -0,0 +1,184 @@
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+
+# === Rotary Embedding Replacement ===
+
+# GPU 0: layers 0–9
+- match:
+ name: "^model\\.layers\\.(0|[1-9])\\."
+ class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbedding
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+# CPU: layers 10-29
+- match:
+ name: "^model\\.layers\\.([12][0-9])\\."
+ class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbedding
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+
+# === Linear Layers Replacement (excluding self_attn) ===
+
+# GPU 0: layers 0–9
+- match:
+ name: "^model\\.layers\\.(0|[1-9])\\.(?!self_attn).*$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+ generate_op: "KLinearMarlin"
+ prefill_op: "KLinearTorch"
+# CPU: layers 10-29
+- match:
+ name: "^model\\.layers\\.([12][0-9])\\.(?!self_attn).*$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+ generate_op: "KLinearCPUInfer"
+ prefill_op: "KLinearTorch"
+ out_device: "cpu"
+
+# === MLP (MoE) Replacement ===
+
+# GPU 0: layers 0–9
+- match:
+ name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
+ class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+# CPU: layers 10-29
+- match:
+ name: "^model\\.layers\\.([12][0-9])\\.mlp$"
+ class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+
+# === MLP Gate Replacement ===
+
+# GPU 0: layers 0–9
+- match:
+ name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.gate$"
+ class: ktransformers.models.modeling_deepseek_v3.MoEGate
+ replace:
+ class: ktransformers.operators.gate.KMoEGate
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+# CPU: layers 10-29
+- match:
+ name: "^model\\.layers\\.([12][0-9])\\.mlp\\.gate$"
+ class: ktransformers.models.modeling_deepseek_v3.MoEGate
+ replace:
+ class: ktransformers.operators.gate.KMoEGate
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+
+# === MLP Experts Replacement ===
+
+# GPU 0: layers 0–9
+- match:
+ name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "cuda:0"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "cuda:0"
+ recursive: False # don't recursively inject submodules of this module
+# CPU: layers 10-29
+- match:
+ name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "cpu"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "cpu"
+ recursive: False # don't recursively inject submodules of this module
+
+# === Self-Attention Replacement ===
+
+# GPU 0: layers 0–9
+- match:
+ name: "^model\\.layers\\.(0|[1-9])\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+# CPU: layers 10-29
+- match:
+ name: "^model\\.layers\\.([12][0-9])\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+
+# === Overall Model Replacement with Transfer Map ===
+
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+ per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
+ transfer_map:
+ 10: "cpu"
+
+# === Default Catch-All for Other Modules ===#
+# GPU 0: layers 0–9
+- match:
+ name: "^model\\.layers\\.(0|[1-9])\\."
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+
+#lmm_head on GPU 0
+- match:
+ name: "^lm_head"
+ class: torch.nn.Linear
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+ generate_op: "KLinearMarlin"
+ prefill_op: "KLinearTorch"
+
+# CPU: layers 10-29
+- match:
+ name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml
new file mode 100644
index 0000000..670f6d5
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml
@@ -0,0 +1,91 @@
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+- match:
+ name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+ generate_op: "KLinearFP8"
+ prefill_op: "KLinearTorch"
+- match:
+ name: "^model\\.layers\\..*\\.mlp$"
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.MoEGate
+ replace:
+ class: ktransformers.operators.gate.KMoEGate
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+- match:
+ name: "^model\\.layers\\..*\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "cuda"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "cuda"
+ backend: "llamafile"
+ recursive: False # don't recursively inject submodules of this module
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+ per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
+ replace:
+ class: ktransformers.operators.layernorm.RMSNorm
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
+ replace:
+ class: ktransformers.operators.mlp.kDeepseekV3MLP
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+
+- match:
+ name: "^lm_head$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+ generate_op: "VLinearMarlin"
+ prefill_op: "KLinearTorch"
\ No newline at end of file
diff --git a/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml
new file mode 100644
index 0000000..5de582f
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml
@@ -0,0 +1,64 @@
+- match:
+ class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbedding
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^model\\.layers\\..*" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+ generate_op: "KLinearIPEXLLM"
+ prefill_op: "KLinearIPEXLLM"
+- match:
+ name: "^model\\.layers\\..*\\.mlp$"
+ class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^model\\.layers\\..*\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "xpu"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "xpu"
+ recursive: False # don't recursively inject submodules of this module
+- match:
+ class: ktransformers.models.modeling_deepseek.DeepseekV2RMSNorm
+ replace:
+ class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+ per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
+ device: "xpu"
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
\ No newline at end of file
diff --git a/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml
new file mode 100644
index 0000000..c0e46c3
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml
@@ -0,0 +1,81 @@
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^lm_head$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+ generate_op: "KLinearIPEXLLM"
+ prefill_op: "KLinearIPEXLLM"
+- match:
+ name: "^model\\.layers\\..*" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+ generate_op: "KLinearIPEXLLM"
+ prefill_op: "KLinearIPEXLLM"
+- match:
+ name: "^model\\.layers\\..*\\.mlp$"
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
+ replace:
+ class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.MoEGate
+ replace:
+ class: ktransformers.operators.gate.KMoEGateIPEXLLM
+ kwargs:
+ generate_device: "xpu:0"
+ prefill_device: "xpu:0"
+- match:
+ name: "^model\\.layers\\..*\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "xpu"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "xpu"
+ recursive: False # don't recursively inject submodules of this module
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+ absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+ per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
\ No newline at end of file
diff --git a/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml
new file mode 100644
index 0000000..6bb4dae
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml
@@ -0,0 +1,80 @@
+- match:
+ name: "rotary_emb$"
+ replace:
+ class: ktransformers.operators.RoPE.KQwen3MoeRotaryEmbedding
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^lm_head$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+ generate_op: "KLinearIPEXLLM"
+ prefill_op: "KLinearIPEXLLM"
+- match:
+ name: "^model\\.layers\\.(?!.*mlp\\.gate).*$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+ generate_op: "KLinearIPEXLLM"
+ prefill_op: "KLinearIPEXLLM"
+- match:
+ name: "^model\\.layers\\..*\\.mlp$"
+ class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
+ replace:
+ class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2 # mlp module with custom forward function
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^model\\.layers\\..*\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "xpu"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "xpu"
+ recursive: False # don't recursively inject submodules of this module
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KQwen3MoeAttentionIPEXLLM
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KQwen2MoeModel"
+ kwargs:
+ per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+- match:
+ class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm
+ replace:
+ class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP
+ replace:
+ class: ktransformers.operators.mlp.KQwen2MoeMLP
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py
index 1210e14..748bd47 100644
--- a/ktransformers/server/args.py
+++ b/ktransformers/server/args.py
@@ -128,10 +128,7 @@ class ArgumentParser:
else:
args.model_dir = self.cfg.model_dir
args.model_path = self.cfg.model_path
- # set config from args
- for key, value in vars(args).items():
- if value is not None and hasattr(self.cfg, key):
- setattr(self.cfg, key, value)
+
# we add the name not match args individually
self.cfg.model_device = args.device
self.cfg.mount_web = args.web
@@ -140,10 +137,15 @@ class ArgumentParser:
self.cfg.user_force_think = args.force_think
model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
- if args.architectures == "Qwen3MoeForCausalLM" or args.architectures == "Qwen2MoeForCausalLM" :
+ if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" :
args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim
+ args.architectures = model_config.architectures[0]
else:
args.gpu_memory_size = args.cache_lens*2*576*61
+ # set config from args
+ for key, value in vars(args).items():
+ if value is not None and hasattr(self.cfg, key):
+ setattr(self.cfg, key, value)
self.cfg.gpu_memory_size = args.gpu_memory_size
free_ports = get_free_ports(3, [args.port])
args.sched_port = free_ports[0]
diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py
index 2d89332..a385c9c 100644
--- a/ktransformers/server/backend/interfaces/balance_serve.py
+++ b/ktransformers/server/backend/interfaces/balance_serve.py
@@ -195,13 +195,13 @@ class Engine:
self.block_num = inference_context.k_cache[0].size(1)
+ self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
#@TODO add config
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
- self.model.init_wrapper(self.args.use_cuda_graph, self.device, 1024 ,args.max_batch_size, self.block_num) # TODO: 1024 is a magic number(max_batch_tokens)
+ self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
else:
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
- self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
self.sampler = Sampler()
self.query_manager = QueryManager(device = self.device, page_size = args.page_size)
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index 6bde540..78cb73f 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -11,6 +11,14 @@ from transformers import (
StaticCache,
AutoModelForCausalLM,
BitsAndBytesConfig,
+ LogitsProcessorList,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+ MinPLogitsWarper,
+ TypicalLogitsWarper,
+ EpsilonLogitsWarper,
+ EtaLogitsWarper,
)
from ktransformers.server.config.config import Config
@@ -206,6 +214,58 @@ class TransformersInterface(BackendInterfaceBase):
self.seq_length += 1
return self.streamer.put(new_tokens)
+ @staticmethod
+ def tf_logits_warper(generation_config):
+ """
+ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
+ used for multinomial sampling.
+ """
+
+ # instantiate warpers list
+ warpers = LogitsProcessorList()
+
+ # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+ # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
+ if generation_config.num_beams > 1:
+ if isinstance(generation_config._eos_token_tensor, list):
+ min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+ elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+ min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
+ else:
+ min_tokens_to_keep = 2
+ else:
+ min_tokens_to_keep = 1
+
+ # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+ # all samplers can be found in `generation_utils_samplers.py`
+ if generation_config.temperature is not None and generation_config.temperature != 1.0:
+ warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+ if generation_config.top_k is not None and generation_config.top_k != 0:
+ warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.top_p is not None and generation_config.top_p < 1.0:
+ warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.min_p is not None:
+ # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+ warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+ warpers.append(
+ TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+ )
+ if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+ warpers.append(
+ EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+ )
+ if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+ warpers.append(
+ EtaLogitsWarper(
+ epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
+ )
+ )
+ # `LogitNormalization` should always be the last logit processor, when present
+ if generation_config.renormalize_logits is True:
+ warpers.append(LogitNormalization())
+ return warpers
+
def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
if temperature is None or temperature == 0:
temperature = self.model.generation_config.temperature
@@ -222,14 +282,8 @@ class TransformersInterface(BackendInterfaceBase):
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
)
self.inputs = inputs
- try: # transformers==4.43
- self.logits_warper = (
- self.model._get_logits_warper(generation_config, device=device)
- )
- except:
- self.logits_warper = (
- self.model._get_logits_warper(generation_config)
- )
+
+ self.logits_warper = self.tf_logits_warper(generation_config)
def logits_to_token(self, logits: torch.Tensor):
logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
diff --git a/ktransformers/server/balance_serve/inference/forward_batch.py b/ktransformers/server/balance_serve/inference/forward_batch.py
index 7022d9e..26b4d3d 100644
--- a/ktransformers/server/balance_serve/inference/forward_batch.py
+++ b/ktransformers/server/balance_serve/inference/forward_batch.py
@@ -200,7 +200,7 @@ class ForwardBatchInput:
device=None,
tokens: torch.Tensor = None,
num_mini_batches: int = 1,
- max_seq_length: int = 1024, # TODO: add to yaml
+ max_seq_length: int = 4096, # TODO: add to yaml
prefill_query_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, # TODO: use config
prefill_active_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size,
gen_prefill: bool = True,
@@ -223,12 +223,12 @@ class ForwardBatchInput:
decode_querys_info = []
for i in range(min(decode_batch_size, cuda_lens)):
- query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, max_seq_length, page_size, device, is_prefill=False, offset=offset)
+ query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, 256, page_size, device, is_prefill=False, offset=offset)
offset += max_seq_length // page_size
if tokens is not None:
query_info.query_tokens[prefill_active_length:prefill_active_length + 1].copy_(tokens)
if decode_active_position is None:
- query_info.active_position = prefill_active_length
+ query_info.active_position = 255
else:
query_info.active_position = decode_active_position[i]
diff --git a/ktransformers/server/balance_serve/inference/model_runner.py b/ktransformers/server/balance_serve/inference/model_runner.py
index 79b3053..55dfb6d 100644
--- a/ktransformers/server/balance_serve/inference/model_runner.py
+++ b/ktransformers/server/balance_serve/inference/model_runner.py
@@ -39,6 +39,17 @@ def pad_num_tokens(num_tokens):
def deduplicate_and_sort(lst):
return sorted(set(lst))
+def generate_cuda_graphs(chunk_size: int) -> list:
+ # 如果输入不符合要求,assert掉
+ assert chunk_size <= 1024 or chunk_size % 1024 == 0, "chunk_size must <= 1024 or a multiple of 1024"
+ base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size]
+
+ if chunk_size <= 1024:
+ return deduplicate_and_sort(base_list)
+
+ multiples = [i for i in range(1024, chunk_size + 1, 1024)]
+
+ return deduplicate_and_sort(base_list + multiples)
class ModelRunner:
"""A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""
@@ -56,7 +67,7 @@ class ModelRunner:
self.features_buf = None
self.output = None
self.graph_memory_pool = None
- self.cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
+ self.cuda_graphs = generate_cuda_graphs(Config().chunk_size)
self.use_cuda_graph = use_cuda_graph
self.model_time = 0
self.page_size = page_size
diff --git a/ktransformers/tests/dequant_gpu.py b/ktransformers/tests/dequant_gpu.py
index 0dd5272..3dbd794 100644
--- a/ktransformers/tests/dequant_gpu.py
+++ b/ktransformers/tests/dequant_gpu.py
@@ -7,7 +7,7 @@ sys.path.append(current_path+"/../..")
import numpy as np
# from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin
# from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader
import torch
import KTransformersOps
torch.set_default_dtype(torch.bfloat16)
diff --git a/ktransformers/tests/dequant_gpu_t.py b/ktransformers/tests/dequant_gpu_t.py
index 4b2556d..06de4a0 100644
--- a/ktransformers/tests/dequant_gpu_t.py
+++ b/ktransformers/tests/dequant_gpu_t.py
@@ -9,7 +9,7 @@ from pycuda.compiler import SourceModule
import numpy as np
from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin
from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch
-from ktransformers.util.custom_gguf import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k
+from ktransformers.util.custom_loader import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k
import torch
import KTransformersOps
torch.set_default_dtype(torch.bfloat16)
diff --git a/ktransformers/tests/test_speed.py b/ktransformers/tests/test_speed.py
index b45bf87..6f435b4 100644
--- a/ktransformers/tests/test_speed.py
+++ b/ktransformers/tests/test_speed.py
@@ -159,5 +159,7 @@ if __name__ == "__main__":
prompt = ktansformer_prompt1024
elif args.prompt_lens == 2048:
prompt = ktansformer_prompt1024 * 2
+ elif args.prompt_lens == 4096:
+ prompt = ktansformer_prompt1024 * 4
asyncio.run(main(args.concurrent, prompt, max_tokens, model))
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index b3d98d3..5e4ffd6 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -24,8 +24,8 @@ from typing import Sequence
import os
from enum import IntEnum
import torch
-import KTransformersOps
-from .custom_loader import SafeTensorLoader
+if not torch.xpu.is_available():
+ import KTransformersOps
import ctypes
import math
@@ -166,238 +166,6 @@ DATA_TYPES = {
"FP8": 13,
}
-class GGUFLoader:
- tensor_info: dict
- gguf_path: str
- tensor_file_map: dict # {tensor_name: tensor_file_path}
- gguf_file_meta: dict
- safetensor_loader: SafeTensorLoader
- def __init__(self, gguf_path: str):
- # Check dir exist
- if not os.path.exists(gguf_path):
- raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
- if os.path.isfile(gguf_path):
- gguf_path = os.path.dirname(gguf_path)
-
- self.safetensor_loader = None
-
- self.tensor_info = {}
- self.gguf_path = gguf_path
- self.tensor_file_map = {}
- self.file_data_map = {}
- self.gguf_file_meta = {}
- self.tensor_device_map = {}
-
- # I know this is ugly, but I don't want to change the original code too much
- # TODO: merge gguf load and other loads.
- safetensor_loader = SafeTensorLoader(gguf_path)
- if safetensor_loader.tensor_file_map:
- self.safetensor_loader = safetensor_loader
- return
- # Walk through all the .gguf files in the directory
- found_gguf = False
- for root, dirs, files in os.walk(gguf_path):
- for file in files:
- if file.endswith(".gguf"):
- found_gguf = True
- file_name = os.path.join(root, file)
- with open(file_name, "rb") as f:
- self.load_gguf(f)
- if file_name not in self.file_data_map:
- self.file_data_map[file_name] = np.memmap(file_name, mode = 'r')
- if not found_gguf:
- raise FileNotFoundError(f"Cannot find any .gguf files in: {gguf_path}")
-
- def load_gguf(self, f):
- f.seek(0)
- assert f.read(4) == b'GGUF'
- values = struct.unpack("torch.Tensor:
- t = self.tensor_info[name]
- if device.lower() == "cpu":
- print(f"loading expert {expert_id} of {name} with CPU")
- shape = t["shape"]
- ggml_type = t["ggml_type"]
- if ggml_type not in GGML_NAMES:
- raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
- ggml_name = GGML_NAMES[ggml_type]
-
- # TODO: experts may fused in quant block, split it
- assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant"
-
- blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name]
- block_size = GGML_BLOCK_SIZES[ggml_name]
- offset = expert_id * block_size * blocks_per_experts
- data = data[offset: offset + block_size * blocks_per_experts]
-
- if "cuda" in device.lower():
- values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
- else:
- values = GGML_DEQUANTIZE[ggml_name](data)
- values = torch.from_numpy(values.copy())
-
- if ggml_name == "BF16":
- values = values.view(torch.bfloat16)
- values = values.view(shape[-2::-1])
-
- return values
-
- def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor:
- t = self.tensor_info[name]
- if device.lower() == "cpu":
- print(f"loading {name} with CPU")
- if target_dtype == None:
- target_dtype = torch.get_default_dtype()
-
- shape = t["shape"]
- ggml_type = t["ggml_type"]
-
- if ggml_type not in GGML_NAMES:
- raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
-
- ggml_name = GGML_NAMES[ggml_type]
-
- data = self.get_mmap_tensor(name)
-
- block_size = GGML_BLOCK_SIZES[ggml_name]
- elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
- num_elements = int(np.prod(shape))
- num_blocks = num_elements // elements_per_block
-
- blocks_per_iter = 16384
- if num_blocks > blocks_per_iter: # dequant large tensor
- values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)
- for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
- blocks_begin = i * blocks_per_iter
- blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
- if "cuda" in device.lower():
- cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
- else:
- cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
- cur_values = torch.from_numpy(cur_values.copy())
-
- cur_values = cur_values.view(-1, elements_per_block)
- if ggml_name == "BF16":
- cur_values = cur_values.view(torch.bfloat16)
- values[blocks_begin : blocks_end] = cur_values
- else:
- if "cuda" in device.lower():
- values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
- else:
- values = GGML_DEQUANTIZE[ggml_name](data)
- values = torch.from_numpy(values)
-
- if ggml_name == "BF16":
- values = values.view(torch.bfloat16)
-
-
- values = values.view(shape[::-1])
- if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
- n_head = self.gguf_file_meta['llama.attention.head_count']
- values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
- .swapaxes(1, 2)
- .reshape(values.shape))
- elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
- n_head = self.gguf_file_meta['llama.attention.head_count_kv']
- values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
- .swapaxes(1, 2)
- .reshape(values.shape))
- return values
def read_value(f, data_type):
if data_type == DATA_TYPES["string"]:
@@ -921,6 +689,7 @@ def translate_name_to_gguf(name):
name = name.replace(".gate_up_proj.", ".up_proj")
name = name.replace(".mlp.shared_experts.down_proj", ".ffn_down_shexp")
+ name = name.replace(".mlp.gate.e_score_correction_bias", ".exp_probs_b.bias")
name = name.replace(".mlp.gate", ".ffn_gate_inp")
name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp")
name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py
index ecc09a0..003f93c 100644
--- a/ktransformers/util/custom_loader.py
+++ b/ktransformers/util/custom_loader.py
@@ -7,15 +7,39 @@ from typing import Sequence
import os
from enum import IntEnum
import torch
-import KTransformersOps
+if not torch.xpu.is_available():
+ import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
+from ktransformers.util.custom_gguf import *
from safetensors.torch import save_file
+from abc import ABC, abstractmethod
+from typing import Dict, Any, Optional, Union
-class SafeTensorLoader:
+class ModelLoader(ABC):
+ """
+ Abstract base class for model loaders.
+ Defines the interface that all model loaders must implement.
+ """
tensor_file_map = {}
- tensor_type_map = {}
- file_handle_map = {}
+ @abstractmethod
+ def has_tensor(cls, name: str):
+ """
+ Check if the tensor exists in the loader.
+
+ Args:
+ name: Name of the tensor to check
+
+ Returns:
+ bool: True if the tensor exists, False otherwise
+ """
+ pass
+
+class SafeTensorLoader(ModelLoader):
+ tensor_file_map: dict
+ tensor_type_map: dict
+ file_handle_map: dict
+ tensor_device_map: dict
def __init__(self, file_path: str):
self.__load_tensor_file_map(file_path)
@@ -28,6 +52,10 @@ class SafeTensorLoader:
folder_path = os.path.dirname(file_path)
else:
folder_path = file_path
+ self.file_handle_map = {}
+ self.tensor_file_map = {}
+ self.tensor_type_map = {}
+ self.tensor_device_map = {}
found_safetensor = False
for root, _, files in os.walk(folder_path):
@@ -57,7 +85,11 @@ class SafeTensorLoader:
# raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
def load_tensor(self, key: str, device: str="cpu"):
- if key not in self.tensor_file_map:
+ if translate_name_to_gguf(key) in self.tensor_file_map:
+ key = translate_name_to_gguf(key)
+ elif key in self.tensor_file_map:
+ pass
+ else:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
@@ -66,13 +98,145 @@ class SafeTensorLoader:
tensor = f.get_tensor(key)
return tensor.to(device)
+ def load_experts(self, key: str, device: str="cpu"):
+ '''
+ Load experts from safetensor
+ key: the name of the experts
+ device: the device to load the experts to
+ return: dict,
+ {up: tensor, down: tensor, gate: tensor, up_type: int, down_type: int, gate_type: int}
+ {xxx}_type: the type of the up tensor, corresponding to the ggml type
+ '''
+ if self.has_tensor(translate_name_to_gguf(key)+".ffn_gate_exps.weight"):
+ # legacy branch for loading hybrid model
+ base_key = translate_name_to_gguf(key)
+ # Load experts from safetensor
+ gate_key = f"{base_key}.ffn_gate_exps.weight"
+ gate_type_key = f"{base_key}.ffn_gate_exps.ggml_type"
+ up_key = f"{base_key}.ffn_up_exps.weight"
+ up_type_key = f"{base_key}.ffn_up_exps.ggml_type"
+ down_key = f"{base_key}.ffn_down_exps.weight"
+ down_type_key = f"{base_key}.ffn_down_exps.ggml_type"
+ gate_tensor = self.load_tensor(gate_key, device).numpy()
+ up_tensor = self.load_tensor(up_key, device).numpy()
+ down_tensor = self.load_tensor(down_key, device).numpy()
+ gate_type = self.load_tensor(gate_type_key, device).item()
+ up_type = self.load_tensor(up_type_key, device).item()
+ down_type = self.load_tensor(down_type_key, device).item()
+
+ return {
+ "up": up_tensor,
+ "gate": gate_tensor,
+ "down": down_tensor,
+ "up_type": up_type,
+ "gate_type": gate_type,
+ "down_type": down_type
+ }
+
+ else:
+ # Load experts from safetensor
+ base_key = key # e.g. "model.layers.3.mlp.experts"
+ experts_count = 0
+
+ # First, count how many experts we have by checking for expert 0's up_proj
+ while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight"):
+ experts_count += 1
+
+ if experts_count == 0:
+ raise ValueError(f"No experts found for key {base_key}")
+
+ # Initialize empty lists to store tensors for each projection type
+ up_projs = []
+ gate_projs = []
+ down_projs = []
+
+ # Load all expert weights
+ for expert_id in range(experts_count):
+ up_key = f"{base_key}.{expert_id}.up_proj.weight"
+ gate_key = f"{base_key}.{expert_id}.gate_proj.weight"
+ down_key = f"{base_key}.{expert_id}.down_proj.weight"
+
+ up_tensor = self.load_tensor(up_key, device)
+ gate_tensor = self.load_tensor(gate_key, device)
+ down_tensor = self.load_tensor(down_key, device)
+
+ up_projs.append(up_tensor)
+ gate_projs.append(gate_tensor)
+ down_projs.append(down_tensor)
+
+ # Stack the tensors along a new dimension
+ up_tensor = torch.stack(up_projs, dim=0)
+ gate_tensor = torch.stack(gate_projs, dim=0)
+ down_tensor = torch.stack(down_projs, dim=0)
+
+ # Get original dtype for GGML type determination
+ orig_up_dtype = up_tensor.dtype
+ orig_gate_dtype = gate_tensor.dtype
+ orig_down_dtype = down_tensor.dtype
+
+ # Convert to numpy with proper bfloat16 support
+ up_numpy = up_tensor.view(torch.uint16).numpy()
+ gate_numpy = gate_tensor.view(torch.uint16).numpy()
+ down_numpy = down_tensor.view(torch.uint16).numpy()
+
+ # Determine tensor data types for GGML conversion
+ def get_ggml_type(dtype):
+ if dtype == torch.float32:
+ return GGMLQuantizationType.F32
+ elif dtype == torch.float16:
+ return GGMLQuantizationType.F16
+ elif dtype == torch.bfloat16:
+ return GGMLQuantizationType.BF16
+ else:
+ raise ValueError(f"Unsupported tensor dtype: {dtype}")
+
+ return {
+ "up": up_numpy,
+ "gate": gate_numpy,
+ "down": down_numpy,
+ "up_type": get_ggml_type(orig_up_dtype),
+ "gate_type": get_ggml_type(orig_gate_dtype),
+ "down_type": get_ggml_type(orig_down_dtype)
+ }
+
+ def load_gate(self, key: str, device: str="cpu"):
+ '''
+ Load gate from safetensor
+ key: the name of the gate
+ device: the device to load the gate to
+ return: dict,
+ {'weight': tensor, 'e_score_correction_bias': tensor}
+ '''
+ target = ["weight", "e_score_correction_bias"]
+ res = {'weight': None, 'e_score_correction_bias': None}
+ if self.has_tensor(translate_name_to_gguf(key)+".ffn_gate_exps.weight"):
+ # legacy branch for loading hybrid model
+ base_key = key
+ for k in target:
+ translated_key = translate_name_to_gguf(f"{base_key}.{k}")
+ if self.has_tensor(translated_key):
+ tensor = self.load_tensor(translated_key, device)
+ res[k] = tensor
+ else:
+ # Load gate from safetensor
+ base_key = key
+ for k in target:
+ if self.has_tensor(f"{base_key}.{k}"):
+ tensor = self.load_tensor(f"{base_key}.{k}", device)
+ res[k] = tensor
+ return res
+
def close_all_handles(self):
for handle in self.file_handle_map.values():
handle.close()
self.file_handle_map.clear()
def load_dequantized_tensor(self, key:str, device: str="cpu"):
- if key not in self.tensor_file_map:
+ if key in self.tensor_file_map and translate_name_to_gguf(key):
+ pass
+ elif translate_name_to_gguf(key) in self.tensor_file_map:
+ key = translate_name_to_gguf(key)
+ else:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
@@ -83,4 +247,320 @@ class SafeTensorLoader:
if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
tensor = weight_dequant(tensor, weight_scale_inv)
- return tensor.to(device)
\ No newline at end of file
+ return tensor.to(device)
+
+ def has_tensor(self, name: str):
+ return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map
+
+class GGUFLoader(ModelLoader):
+ tensor_info: dict
+ gguf_path: str
+ tensor_file_map: dict # {tensor_name: tensor_file_path}
+ gguf_file_meta: dict
+ safetensor_loader: SafeTensorLoader
+ def __init__(self, gguf_path: str):
+ # Check dir exist
+ if not os.path.exists(gguf_path):
+ raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
+ if os.path.isfile(gguf_path):
+ gguf_path = os.path.dirname(gguf_path)
+
+ self.safetensor_loader = None
+
+ self.tensor_info = {}
+ self.gguf_path = gguf_path
+ self.tensor_file_map = {}
+ self.file_data_map = {}
+ self.gguf_file_meta = {}
+ self.tensor_device_map = {}
+
+ # Walk through all the .gguf files in the directory
+ found_gguf = False
+ for root, dirs, files in os.walk(gguf_path):
+ for file in files:
+ if file.endswith(".gguf"):
+ found_gguf = True
+ file_name = os.path.join(root, file)
+ with open(file_name, "rb") as f:
+ self.load_gguf(f)
+ if file_name not in self.file_data_map:
+ self.file_data_map[file_name] = np.memmap(file_name, mode = 'r')
+ if not found_gguf:
+ raise FileNotFoundError(f"Cannot find any .gguf files in: {gguf_path}")
+
+ def load_gguf(self, f):
+ f.seek(0)
+ assert f.read(4) == b'GGUF'
+ values = struct.unpack("torch.Tensor:
+ name = translate_name_to_gguf(name)
+ t = self.tensor_info[name]
+ shape = t["shape"]
+ ggml_type = t["ggml_type"]
+ if ggml_type not in GGML_NAMES:
+ raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
+ ggml_name = GGML_NAMES[ggml_type]
+
+ # TODO: experts may fused in quant block, split it
+ assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant"
+
+ blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name]
+ block_size = GGML_BLOCK_SIZES[ggml_name]
+ offset = expert_id * block_size * blocks_per_experts
+ data = data[offset: offset + block_size * blocks_per_experts]
+
+ if "cuda" in device.lower():
+ values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
+ else:
+ values = GGML_DEQUANTIZE[ggml_name](data)
+ values = torch.from_numpy(values.copy())
+
+ if ggml_name == "BF16":
+ values = values.view(torch.bfloat16)
+ values = values.view(shape[-2::-1])
+
+ return values
+
+ def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor:
+ name = translate_name_to_gguf(name)
+ t = self.tensor_info[name]
+ if target_dtype == None:
+ target_dtype = torch.get_default_dtype()
+
+ shape = t["shape"]
+ ggml_type = t["ggml_type"]
+
+ if ggml_type not in GGML_NAMES:
+ raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
+
+ ggml_name = GGML_NAMES[ggml_type]
+
+ data = self.get_mmap_tensor(name)
+
+ block_size = GGML_BLOCK_SIZES[ggml_name]
+ elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
+ num_elements = int(np.prod(shape))
+ num_blocks = num_elements // elements_per_block
+
+ blocks_per_iter = 16384
+ if num_blocks > blocks_per_iter: # dequant large tensor
+ values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)
+ for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
+ blocks_begin = i * blocks_per_iter
+ blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
+ if "cuda" in device.lower():
+ try:
+ cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
+ except:
+ cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
+ cur_values = torch.from_numpy(cur_values.copy()).to(device)
+ else:
+ cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
+ cur_values = torch.from_numpy(cur_values.copy())
+
+ cur_values = cur_values.view(-1, elements_per_block)
+ if ggml_name == "BF16":
+ cur_values = cur_values.view(torch.bfloat16)
+ values[blocks_begin : blocks_end] = cur_values
+ else:
+ if "cuda" in device.lower():
+ values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+ else:
+ np_values = np.copy(GGML_DEQUANTIZE[ggml_name](data))
+ values = torch.from_numpy(np_values).to(device)
+ del np_values
+
+ if ggml_name == "BF16":
+ values = values.view(torch.bfloat16)
+
+
+ values = values.view(shape[::-1])
+ if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
+ n_head = self.gguf_file_meta['llama.attention.head_count']
+ values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
+ .swapaxes(1, 2)
+ .reshape(values.shape))
+ elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
+ n_head = self.gguf_file_meta['llama.attention.head_count_kv']
+ values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
+ .swapaxes(1, 2)
+ .reshape(values.shape))
+ return values
+ def has_tensor(self, name: str):
+ name = translate_name_to_gguf(name)
+ return name in self.tensor_info
+
+ def get_ggml_type(self, name: str):
+ name = translate_name_to_gguf(name)
+ if name not in self.tensor_info:
+ raise KeyError(f"Key {name} not found in GGUF files")
+ return self.tensor_info[name]["ggml_type"]
+
+class ModelLoaderFactory:
+ """
+ Factory class for creating model loaders.
+ Automatically detects the model format based on file extensions in the directory.
+ """
+
+ @staticmethod
+ def create_loader(path: str):
+ """
+ Create a model loader for the given path by detecting the model format.
+ The function checks for the presence of .safetensors or .gguf files
+ in the specified path and creates the appropriate loader.
+
+ Args:
+ path: Path to the model directory or file
+
+ Returns:
+ An appropriate ModelLoader instance (SafeTensorLoader or GGUFLoader)
+
+ Raises:
+ FileNotFoundError: If no supported model files are found in the path
+ """
+ if not os.path.exists(path):
+ raise FileNotFoundError(f"Path not found: {path}")
+
+ # Normalize to directory path if a file was provided
+ if os.path.isfile(path):
+ if path.endswith(".safetensors"):
+ return SafeTensorLoader(path)
+ elif path.endswith(".gguf"):
+ return GGUFLoader(path)
+ else:
+ folder_path = os.path.dirname(path)
+ else:
+ folder_path = path
+
+ # Check for safetensors files
+ has_safetensors = False
+ has_gguf = False
+
+ for root, _, files in os.walk(folder_path):
+ for file in files:
+ if file.endswith(".safetensors"):
+ has_safetensors = True
+ break
+ elif file.endswith(".gguf"):
+ has_gguf = True
+ break
+ if has_safetensors or has_gguf:
+ break
+
+ # Create the appropriate loader based on detected file types
+ # Prioritize SafeTensor over GGUF if both are present
+ if has_safetensors:
+ try:
+ return SafeTensorLoader(folder_path)
+ except Exception as e:
+ print(f"Failed to create SafeTensorLoader: {e}")
+ # Fall through to try GGUF if SafeTensor fails
+ if not has_gguf:
+ raise
+
+ if has_gguf:
+ try:
+ return GGUFLoader(folder_path)
+ except Exception as e:
+ print(f"Failed to create GGUFLoader: {e}")
+ raise
+
+ # No supported model files found
+ raise FileNotFoundError(f"No .safetensors or .gguf files found in: {folder_path}")
\ No newline at end of file
diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index 30f8880..98a44f2 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -11,13 +11,24 @@ from torch import nn
import itertools
import time
import enum
-from ktransformers.util.custom_gguf import translate_name_to_gguf
-from ktransformers.util.custom_gguf import GGUFLoader
+from transformers import (
+ LogitsProcessorList,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+ MinPLogitsWarper,
+ TypicalLogitsWarper,
+ EpsilonLogitsWarper,
+ EtaLogitsWarper,
+)
+
+from ktransformers.util.custom_loader import ModelLoaderFactory, ModelLoader, SafeTensorLoader, GGUFLoader, translate_name_to_gguf
from ktransformers.operators import base_operator
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer
-from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
+if not torch.xpu.is_available():
+ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
import socket
warm_uped = False
@@ -49,6 +60,8 @@ def get_compute_capability(device:torch.device = None):
return min_compute_capability_major
else:
return torch.cuda.get_device_properties(device)
+ else:
+ return 0
def set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
@@ -87,45 +100,132 @@ def get_all_used_cuda_device(device_map:dict):
all_device_list = list(all_device_list)
return all_device_list
-def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
+def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
prefix = prefix.replace("orig_module.", "")
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
- translated_key = translate_name_to_gguf(key)
+ translated_key = key
# TODO: Merge all loader.
# I know this is ugly but lets do it for now.
- if gguf_loader.safetensor_loader is not None:
- load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
- tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
+ if isinstance(gguf_loader, SafeTensorLoader):
+ load_dequantized_tensor = gguf_loader.load_dequantized_tensor
else:
load_dequantized_tensor = gguf_loader.load_gguf_tensor
tensor_file_map = gguf_loader.tensor_file_map
- if translated_key in tensor_file_map:
+ if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
- torch.cuda.empty_cache()
- weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
- set_param(module, name, weights)
- del weights
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ elif torch.xpu.is_available():
+ torch.xpu.empty_cache()
+ if "kv_b_proj" in translated_key and not gguf_loader.has_tensor(translated_key):
+ attn_k_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_k_b"), device=device).to(dtype=target_dtype)
+ attn_k_b = attn_k_b.transpose(1, 2).contiguous()
+ attn_v_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_v_b"), device=device).to(dtype=target_dtype)
+ kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)
+ kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous()
+ set_param(module, name, kv_b_proj)
+ del attn_k_b
+ del attn_v_b
+ else:
+ weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
+ set_param(module, name, weights)
+ del weights
else:
#print(load_config.tensor_file_map.keys())
raise Exception(f"can't find {translated_key} in GGUF file!")
-def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
+
+def sync_all_device(all_device_list):
+ for device in all_device_list:
+ if "cuda" in device.lower():
+ torch.cuda.synchronize(device)
+ elif "xpu" in device.lower():
+ torch.xpu.synchronize(device)
+ else:
+ raise RuntimeError("The device {} is not available".format(device))
+
+torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}
+
+def xpu_fp16_model(config):
+ # This function is to check if we run this model on XPU with FP16 dtype
+ if not torch.xpu.is_available():
+ return False
+ if config.architectures[0] == "DeepseekV3ForCausalLM":
+ return True
+ if config.architectures[0] == "Qwen3MoeForCausalLM" and config.hidden_size == 4096:
+ # Qwen3-30B seems have precision issue with FP16
+ # so we only use FP16 for Qwen3-235B now
+ return True
+ return False
+
+def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
#print(f"recursively loading weights {prefix}")
if not isinstance(module, base_operator.BaseInjectedModule):
- load_cur_state_dict(module, gguf_loader, prefix)
+ load_cur_state_dict(module, gguf_loader, prefix, device=device)
for name, child in module._modules.items():
- load_weights(child, gguf_loader, prefix+name+".")
+ load_weights(child, gguf_loader, prefix+name+".", device=device)
else:
module.load()
+def tf_logits_warper(generation_config):
+ """
+ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
+ used for multinomial sampling.
+ """
+
+ # instantiate warpers list
+ warpers = LogitsProcessorList()
+
+ # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+ # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
+ if generation_config.num_beams > 1:
+ if isinstance(generation_config._eos_token_tensor, list):
+ min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+ elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+ min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
+ else:
+ min_tokens_to_keep = 2
+ else:
+ min_tokens_to_keep = 1
+
+ # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+ # all samplers can be found in `generation_utils_samplers.py`
+ if generation_config.temperature is not None and generation_config.temperature != 1.0:
+ warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+ if generation_config.top_k is not None and generation_config.top_k != 0:
+ warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.top_p is not None and generation_config.top_p < 1.0:
+ warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.min_p is not None:
+ # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+ warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+ warpers.append(
+ TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+ )
+ if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+ warpers.append(
+ EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+ )
+ if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+ warpers.append(
+ EtaLogitsWarper(
+ epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
+ )
+ )
+ # `LogitNormalization` should always be the last logit processor, when present
+ if generation_config.renormalize_logits is True:
+ warpers.append(LogitNormalization())
+ return warpers
+
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
@@ -134,8 +234,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
torch._dynamo.config.suppress_errors = True
batch_size, seq_length = inputs.shape
device_map = model.gguf_loader.tensor_device_map
- torch_device = get_device('blk.0.self_attn', device_map)
- torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+ torch_device = get_device('model.layers.0.self_attn', device_map)
+ torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
inputs = inputs.to(torch_device)
all_cuda_device = get_all_used_cuda_device(device_map)
@@ -148,7 +248,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
logits = cuda_graph_runner(cur_token, position_ids, cache_position)
else:
# custom_stream = torch.cuda.Stream()
- torch.cuda.set_device(torch_device)
+ if torch.cuda.is_available():
+ torch.cuda.set_device(torch_device)
+ elif torch.xpu.is_available():
+ torch.xpu.set_device(torch_device)
+ else:
+ raise RuntimeError(f"The device: {torch_device} is not available")
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
# with torch.cuda.stream(custom_stream):
logits=model(inputs_embeds=inputs_embeds,
@@ -156,10 +261,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
- if past_key_values != None:
+ if past_key_values != None and isinstance(past_key_values, StaticCache):
past_key_values.change_seq_length(1)
- for device in all_cuda_device:
- torch.cuda.synchronize(device)
+ sync_all_device(all_cuda_device)
#print(logits)
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
@@ -185,11 +289,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
return logits
- torch.cuda.set_device(torch_device)
+ if torch.cuda.is_available():
+ torch.cuda.set_device(torch_device)
+ elif torch.xpu.is_available():
+ torch.xpu.set_device(torch_device)
+ else:
+ raise RuntimeError(f"The device: {torch_device} is not available")
with torch.no_grad():
stream = TextStreamer(tokenizer)
- if mode != 'long_context':
+ if torch.xpu.is_available():
+ from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache
+ if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
+ past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+ else:
+ past_key_values = DynamicNormalCache.from_legacy_cache(None)
+ elif mode != 'long_context':
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
)
@@ -201,14 +316,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
# change this to modify generate config
#top_k=5, top_p=0.85, temperature=0.1
)
- try: # transformers==4.43
- logits_warper = (
- model._get_logits_warper(generation_config,device=inputs.device)
- )
- except:
- logits_warper = (
- model._get_logits_warper(generation_config)
- )
+
+ logits_warper = tf_logits_warper(generation_config)
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
generated_ids = torch.zeros(
diff --git a/ktransformers/util/weight_loader.py b/ktransformers/util/weight_loader.py
new file mode 100644
index 0000000..9dda646
--- /dev/null
+++ b/ktransformers/util/weight_loader.py
@@ -0,0 +1,367 @@
+from abc import ABC, abstractmethod
+import os
+import torch
+import numpy as np
+from safetensors import safe_open
+from typing import Dict, Any, Optional, Union
+
+class ModelLoader(ABC):
+ """
+ Abstract base class for model loaders.
+ Defines the interface that all model loaders must implement.
+ """
+
+ @abstractmethod
+ def load_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
+ """
+ Load a tensor by name.
+
+ Args:
+ name: Name of the tensor to load
+ device: Device to load the tensor to
+
+ Returns:
+ The loaded tensor
+ """
+ pass
+
+ @classmethod
+ @abstractmethod
+ def supports_format(cls, path: str) -> bool:
+ """
+ Check if this loader supports the given path format.
+
+ Args:
+ path: Path to check
+
+ Returns:
+ True if this loader supports the given path, False otherwise
+ """
+ pass
+
+
+class SafeTensorLoader(ModelLoader):
+ """
+ Loader for SafeTensor format models.
+ """
+
+ def __init__(self, path: str):
+ """
+ Initialize the SafeTensor loader.
+
+ Args:
+ path: Path to the model directory or file
+ """
+ self.tensor_file_map = {} # Maps tensor names to file paths
+ self.file_handle_map = {} # Maps file names to file handles
+ self._load_tensor_file_map(path)
+
+ def _load_tensor_file_map(self, path: str) -> None:
+ """
+ Load the tensor file map from the given path.
+
+ Args:
+ path: Path to the model directory or file
+ """
+ # Normalize path to directory
+ if not os.path.exists(path):
+ raise FileNotFoundError(f"Path not found: {path}")
+ if os.path.isfile(path):
+ folder_path = os.path.dirname(path)
+ else:
+ folder_path = path
+
+ found_safetensor = False
+ for root, _, files in os.walk(folder_path):
+ files = sorted(files)
+ for file in files:
+ if file.endswith(".safetensors"):
+ found_safetensor = True
+ file_path = os.path.join(root, file)
+ if file not in self.file_handle_map:
+ try:
+ handle = safe_open(file_path, framework="pt")
+ self.file_handle_map[file] = handle
+ except Exception as e:
+ print(f"Error opening Safetensor file {file_path}: {e}")
+ continue
+
+ f = self.file_handle_map.get(file)
+ if f is None:
+ continue
+ try:
+ for key in f.keys():
+ self.tensor_file_map[key] = file
+ except Exception as e:
+ print(f"Error reading Safetensor file {file_path}: {e}")
+
+ if not found_safetensor:
+ # Not raising an error here allows for the factory to try other loaders
+ print(f"No Safetensor files found in {folder_path}")
+
+ def load_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
+ """
+ Load a tensor by name.
+
+ Args:
+ name: Name of the tensor to load
+ device: Device to load the tensor to
+
+ Returns:
+ The loaded tensor
+ """
+ if name not in self.tensor_file_map:
+ raise KeyError(f"Key {name} not found in Safetensor files")
+ file = self.tensor_file_map[name]
+ f = self.file_handle_map.get(file)
+ if f is None:
+ raise FileNotFoundError(f"File {file} not found in Safetensor files")
+ tensor = f.get_tensor(name)
+ return tensor.to(device)
+
+ def load_dequantized_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
+ """
+ Load and dequantize a tensor.
+
+ Args:
+ name: Name of the tensor to load
+ device: Device to load the tensor to
+
+ Returns:
+ The dequantized tensor
+ """
+ if name not in self.tensor_file_map:
+ raise KeyError(f"Key {name} not found in Safetensor files")
+ file = self.tensor_file_map[name]
+ f = self.file_handle_map.get(file)
+ if f is None:
+ raise FileNotFoundError(f"File {file} not found in Safetensor files")
+ tensor = f.get_tensor(name).to(device)
+ if name.endswith(".weight"):
+ if name[:-7] + ".weight_scale_inv" in self.tensor_file_map:
+ weight_scale_inv = f.get_tensor(name[:-7] + ".weight_scale_inv").to(device)
+ # Assuming weight_dequant function is imported
+ from ktransformers.ktransformers_ext.triton.fp8gemm import weight_dequant
+ tensor = weight_dequant(tensor, weight_scale_inv)
+ return tensor.to(device)
+
+ def close_all_handles(self) -> None:
+ """
+ Close all file handles.
+ """
+ for handle in self.file_handle_map.values():
+ handle.close()
+ self.file_handle_map.clear()
+
+ @classmethod
+ def supports_format(cls, path: str) -> bool:
+ """
+ Check if this loader supports the given path format.
+
+ Args:
+ path: Path to check
+
+ Returns:
+ True if safetensor files are found in the path, False otherwise
+ """
+ # Normalize path to directory
+ if not os.path.exists(path):
+ return False
+ if os.path.isfile(path):
+ if path.endswith(".safetensors"):
+ return True
+ folder_path = os.path.dirname(path)
+ else:
+ folder_path = path
+
+ # Check if any safetensor files exist in the folder
+ for root, _, files in os.walk(folder_path):
+ for file in files:
+ if file.endswith(".safetensors"):
+ return True
+ return False
+
+
+class GGUFLoader(ModelLoader):
+ """
+ Loader for GGUF format models.
+ """
+
+ def __init__(self, path: str):
+ """
+ Initialize the GGUF loader.
+
+ Args:
+ path: Path to the model directory or file
+ """
+ # Check if path exists
+ if not os.path.exists(path):
+ raise FileNotFoundError(f"GGUF dir not found: {path}")
+ if os.path.isfile(path):
+ self.gguf_path = os.path.dirname(path)
+ else:
+ self.gguf_path = path
+
+ self.tensor_info = {} # Stores tensor metadata
+ self.tensor_file_map = {} # Maps tensor names to file paths
+ self.file_data_map = {} # Maps file paths to memory-mapped data
+ self.gguf_file_meta = {} # Stores GGUF metadata
+
+ # For compatibility with the factory pattern
+ self.safetensor_loader = None
+
+ # Scan all GGUF files in the directory
+ found_gguf = False
+ for root, _, files in os.walk(self.gguf_path):
+ for file in files:
+ if file.endswith(".gguf"):
+ found_gguf = True
+ file_path = os.path.join(root, file)
+ with open(file_path, "rb") as f:
+ self._load_gguf(f)
+ if file_path not in self.file_data_map:
+ self.file_data_map[file_path] = np.memmap(file_path, mode='r')
+
+ if not found_gguf:
+ raise FileNotFoundError(f"Cannot find any .gguf files in: {self.gguf_path}")
+
+ def _load_gguf(self, f) -> None:
+ """
+ Load GGUF file metadata and tensor info.
+
+ Args:
+ f: File handle of the GGUF file
+ """
+ # Implementation should follow the original GGUFLoader._load_gguf
+ # This is a simplified version for illustration
+ f.seek(0)
+ assert f.read(4) == b'GGUF'
+
+ # Read header
+ values = struct.unpack(" Any:
+ """
+ Read a value from the file according to its data type.
+
+ Args:
+ f: File handle
+ data_type: Type of data to read
+
+ Returns:
+ The read value
+ """
+ # Simplified implementation
+ # In a complete implementation, this would handle all data types
+ if data_type == 8: # DATA_TYPES["string"]
+ length = struct.unpack(" torch.Tensor:
+ """
+ Load a tensor by name.
+
+ Args:
+ name: Name of the tensor to load
+ device: Device to load the tensor to
+
+ Returns:
+ The loaded tensor
+ """
+ # This should call load_gguf_tensor with the appropriate parameters
+ return self.load_gguf_tensor(name, device)
+
+ def load_gguf_tensor(self, name: str, device: str = "cpu", target_dtype = None) -> torch.Tensor:
+ """
+ Load a GGUF tensor by name.
+
+ Args:
+ name: Name of the tensor to load
+ device: Device to load the tensor to
+ target_dtype: Target data type for the tensor
+
+ Returns:
+ The loaded tensor
+ """
+ # Implementation would follow the original GGUFLoader.load_gguf_tensor
+ # This is a placeholder for illustration
+ if name not in self.tensor_info:
+ raise KeyError(f"Tensor {name} not found")
+
+ # Actual implementation would dequantize the tensor data
+ # and return a torch.Tensor
+ return torch.zeros(1, device=device) # Placeholder
+
+ @classmethod
+ def supports_format(cls, path: str) -> bool:
+ """
+ Check if this loader supports the given path format.
+
+ Args:
+ path: Path to check
+
+ Returns:
+ True if GGUF files are found in the path, False otherwise
+ """
+ # Normalize path to directory
+ if not os.path.exists(path):
+ return False
+ if os.path.isfile(path):
+ return path.endswith(".gguf")
+
+ # Check if any GGUF files exist in the folder
+ for root, _, files in os.walk(path):
+ for file in files:
+ if file.endswith(".gguf"):
+ return True
+ return False
\ No newline at end of file
diff --git a/merge_tensors/merge_safetensor_gguf.py b/merge_tensors/merge_safetensor_gguf.py
index efeab3b..f299ab9 100644
--- a/merge_tensors/merge_safetensor_gguf.py
+++ b/merge_tensors/merge_safetensor_gguf.py
@@ -6,7 +6,7 @@ import sys
# sys.path.insert(0, "/home/azure/ktransformers")
import argparse
import torch
-from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
+from ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf
from safetensors import safe_open
from safetensors.torch import save_file
import re
diff --git a/pyproject.toml b/pyproject.toml
index fbf0924..9502c55 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,6 @@ dependencies = [
"build",
"fire",
"protobuf",
- "triton >= 3.2"
]
requires-python = ">=3.10"
@@ -70,7 +69,7 @@ ktransformers = "ktransformers.server.main:main"
[tool.setuptools.packages.find]
where = ["./", ]
-include = ["ktransformers"]
+include = ["ktransformers","ktransformers.*"]
[tool.black]
line-length = 120
preview = true
diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt
index dd3a206..25afaef 100644
--- a/requirements-local_chat.txt
+++ b/requirements-local_chat.txt
@@ -7,4 +7,3 @@ cpufeature; sys_platform == 'win32' or sys_platform == 'Windows'
protobuf
tiktoken
blobfile
-triton>=3.2
diff --git a/setup.py b/setup.py
index 8ea6797..25e8a35 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,17 @@ try:
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
except ImportError:
MUSA_HOME=None
-
+KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()
+
+# 检测 DEV_BACKEND 环境变量
+dev_backend = os.environ.get("DEV_BACKEND", "").lower()
+if dev_backend == "xpu":
+ triton_dep = [
+ "pytorch-triton-xpu==3.3.0"
+ ]
+else:
+ triton_dep = ["triton>=3.2"]
+
with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"
class CpuInstructInfo:
@@ -241,8 +251,10 @@ class VersionInfo:
backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
elif ROCM_HOME is not None:
backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
+ elif torch.xpu.is_available():
+ backend_version = f"xpu"
else:
- raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
+ raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set and XPU is not available.")
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
if full_version:
return package_version
@@ -511,8 +523,10 @@ class CMakeBuild(BuildExtension):
cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
elif ROCM_HOME is not None:
cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
+ elif KTRANSFORMERS_BUILD_XPU:
+ cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
else:
- raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.")
+ raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set and XPU is not available.")
cmake_args = get_cmake_abi_args(cmake_args)
# log cmake_args
@@ -636,33 +650,41 @@ elif MUSA_HOME is not None:
]
}
)
+elif torch.xpu.is_available(): #XPUExtension is not available now.
+ ops_module = None
else:
- raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+ raise ValueError("Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.")
-ext_modules = [
- CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
- ops_module,
- CUDAExtension(
- 'vLLMMarlin', [
- 'csrc/custom_marlin/binding.cpp',
- 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
- 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
- ],
- extra_compile_args={
- 'cxx': ['-O3'],
- 'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
- },
- )
-]
-if with_balance:
- print("using balance_serve")
- ext_modules.append(
- CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
- )
+if not torch.xpu.is_available():
+ ext_modules = [
+ CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+ ops_module,
+ CUDAExtension(
+ 'vLLMMarlin', [
+ 'csrc/custom_marlin/binding.cpp',
+ 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
+ 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
+ ],
+ extra_compile_args={
+ 'cxx': ['-O3'],
+ 'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
+ },
+ )
+ ]
+ if with_balance:
+ print("using balance_serve")
+ ext_modules.append(
+ CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
+ )
+else:
+ ext_modules = [
+ CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+ ]
setup(
name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
+ install_requires=triton_dep,
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=ext_modules
)