Merge pull request #1295 from rnwang04/xpu_support
Enable ktransformers on Intel GPU with local chat backend
Commit f7ee993fdc
22 changed files with 673 additions and 81 deletions
@@ -23,6 +23,8 @@ Our vision for KTransformers is to serve as a flexible platform for experimenting

 <h2 id="Updates">🔥 Updates</h2>

+* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).
 * **Apr 29, 2025**: Support AMX-Int8, AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))

 https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2
@@ -41,6 +41,7 @@ option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW
 option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" ON)
 option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
 option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
+option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)

 # Architecture specific
 # TODO: probably these flags need to be tweaked on some architectures

@@ -303,6 +304,8 @@ elseif (UNIX)
         message(STATUS "MUSA Toolkit found")
         add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
     endif()
+elseif (KTRANSFORMERS_USE_XPU)
+    add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
 else()
     find_package(CUDA REQUIRED)
     include_directories("${CUDA_INCLUDE_DIRS}")

@@ -361,6 +364,7 @@ elseif(UNIX)
         message(STATUS "Building for HIP")
     elseif(KTRANSFORMERS_USE_MUSA)
         target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
+    elseif(KTRANSFORMERS_USE_XPU)
     else()
         target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
     endif()
@@ -17,6 +17,7 @@
 #include <queue>
 #include <thread>
 #include <vector>
+#include <stdexcept>
 #ifdef KTRANSFORMERS_USE_CUDA
 #include "vendors/cuda.h"
 #elif KTRANSFORMERS_USE_MUSA

@@ -66,10 +67,14 @@
     }

     void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
+#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)
         void (*func)(void*) = (void (*)(void*))params.first;
         void* args = (void*)params.second;
         *((CPUInfer**)args) = this;
         cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
+#else
+        throw std::runtime_error("submit_with_cuda_stream is not supported on this platform");
+#endif
     }

     static void sync_(void* cpu_infer_ptr) {

@@ -78,7 +83,11 @@
     }

     void sync_with_cuda_stream(intptr_t user_cuda_stream) {
+#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)
         cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
+#else
+        throw std::runtime_error("sync_with_cuda_stream is not supported on this platform");
+#endif
     }

 public:
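The practical effect of guarding these entry points is that CUDA-stream submission is only attempted where a CUDA-like runtime exists; on XPU the caller falls back to plain submit/sync, as the expert operators further down in this diff do. A minimal sketch of the calling pattern, using the CPUInfer binding names that appear elsewhere in this change (the `moe_task` handle is a placeholder):

```python
import torch

def run_cpu_task(cpu_infer, moe_task):
    """Sketch: drive a CPUInfer task through the CUDA stream or synchronously."""
    if torch.cuda.is_available():
        # CUDA/MUSA/ROCm builds: hand the host callback to the current GPU stream
        # so the CPU work is ordered against other stream operations.
        stream = torch.cuda.current_stream().cuda_stream
        cpu_infer.submit_with_cuda_stream(stream, moe_task)
        cpu_infer.sync_with_cuda_stream(stream)
    else:
        # XPU (or any build without CUDA stream interop): the *_with_cuda_stream
        # entry points now raise, so submit and wait synchronously instead.
        cpu_infer.submit(moe_task)
        cpu_infer.sync()
```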
|
@ -9,7 +9,7 @@
|
||||||
**/
|
**/
|
||||||
// Python bindings
|
// Python bindings
|
||||||
#include "cpu_backend/cpuinfer.h"
|
#include "cpu_backend/cpuinfer.h"
|
||||||
#ifndef KTRANSFORMERS_USE_ROCM
|
#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)
|
||||||
#include "device_launch_parameters.h"
|
#include "device_launch_parameters.h"
|
||||||
#endif
|
#endif
|
||||||
#include "llamafile/flags.h"
|
#include "llamafile/flags.h"
|
||||||
|
|
@@ -21,7 +21,7 @@ interface, RESTful APIs compliant with OpenAI and Ollama, and even a simplified
 Our vision for KTransformers is to serve as a flexible platform for experimenting with innovative LLM inference optimizations. Please let us know if you need any other features.

 <h2 id="Updates">🔥 Updates</h2>
+* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./en/xpu.md)).
 * **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./en/llama4.md)).
 * **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./en/balance-serve.md)).
 * **Mar 27, 2025**: Support Multi-concurrency.
doc/en/xpu.md (new file, 117 lines)
@@ -0,0 +1,117 @@

# Intel GPU Support for KTransformers (Beta)

## Introduction

### Overview
We are excited to introduce **Intel GPU support** in KTransformers (Beta release). This implementation has been tested and developed using Intel Xeon Scalable processors and Intel Arc GPUs (such as the A770 and B580).

## Installation Guide

### 1. Install Intel GPU Driver
Begin by installing the GPU drivers for your Intel GPU:
- [Official GPU Installation Guide for Intel GPUs](https://dgpu-docs.intel.com/driver/overview.html)

> [!Important]
> Ensure that **Resizable BAR** is enabled in your system's BIOS before proceeding. This is essential for optimal GPU performance and avoids potential issues such as `Bus error (core dumped)`. For detailed steps, please refer to the official guidance [here](https://www.intel.com/content/www/us/en/support/articles/000090831/graphics.html).

### 2. Set Up Conda Environment
We recommend using Miniconda3/Anaconda3 for environment management:

```bash
# Download Miniconda
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh

# Create environment
conda create --name ktransformers python=3.11
conda activate ktransformers

# Install required libraries
conda install -c conda-forge libstdcxx-ng

# Verify GLIBCXX version (should include 3.4.32)
strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
```

> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`.

### 3. Install PyTorch and IPEX-LLM
Install PyTorch with XPU backend support and [IPEX-LLM](https://github.com/intel/ipex-llm):

```bash
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu
pip uninstall torch torchvision torchaudio
pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu # install torch 2.7
pip install packaging ninja cpufeature numpy
pip uninstall intel-opencl-rt dpcpp-cpp-rt
```

### 4. Build ktransformers

```bash
# Clone repository
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule update --init

# Install dependencies
bash install.sh
pip uninstall triton pytorch-triton-xpu
pip install pytorch-triton-xpu==3.3.0 --extra-index-url https://download.pytorch.org/whl/xpu # to avoid a potential triton import error
```

## Running DeepSeek-R1 Models

### Configuration for 16GB VRAM GPUs
Use our optimized configuration for constrained VRAM:

```bash
export SYCL_CACHE_PERSISTENT=1
export ONEAPI_DEVICE_SELECTOR=level_zero:0
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1

python ktransformers/local_chat.py \
  --model_path deepseek-ai/DeepSeek-R1 \
  --gguf_path <path_to_gguf_files> \
  --optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \
  --cpu_infer <cpu_cores + 1> \
  --device xpu \
  --max_new_tokens 200
```

## Known Limitations
- The serving function is not supported on the Intel GPU platform for now.

## Troubleshooting

1. Best Known Configuration (BKC) for best performance

To obtain the best performance on the Intel GPU platform, we recommend locking the GPU frequency and setting the CPU to performance mode with the settings below.
```bash
echo "performance" | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
echo 0 | sudo tee /sys/devices/system/cpu/cpu*/power/energy_perf_bias
# 2400 is the max frequency for Arc A770
sudo xpu-smi config -d 0 -t 0 --frequencyrange 2400,2400
# 2850 is the max frequency for Arc B580
# sudo xpu-smi config -d 0 -t 0 --frequencyrange 2850,2850
```

2. Runtime error like `xpu/sycl/TensorCompareKernels.cpp:163: xxx. Aborted (core dumped)`

This error is mostly related to the GPU driver. If you hit it, update `intel-level-zero-gpu` to `1.3.29735.27-914~22.04` (a version we have verified) with the commands below.
```bash
wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \
  sudo gpg --dearmor --output /usr/share/keyrings/intel-graphics.gpg
echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu jammy client" | \
  sudo tee /etc/apt/sources.list.d/intel-gpu-jammy.list
sudo apt update
# or sudo apt update --allow-insecure-repositories
sudo apt install intel-level-zero-gpu=1.3.29735.27-914~22.04
```

3. `ImportError: cannot import name 'intel' from 'triton._C.libtriton'`

Installing Triton causes pytorch-triton-xpu to stop working. You can resolve the issue with the following commands:
```bash
pip uninstall triton pytorch-triton-xpu
# Reinstall the correct version of pytorch-triton-xpu
pip install pytorch-triton-xpu==3.3.0 --index-url https://download.pytorch.org/whl/xpu
```
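As a quick sanity check after finishing the tutorial above, the XPU backend that the rest of this change keys off can be queried directly from Python. This is an illustrative sketch, assuming the `torch==2.7+xpu` build from step 3 is installed and that the `torch.xpu` device queries mirror the `torch.cuda` ones:

```python
import torch

# Minimal check that the XPU-enabled PyTorch build and the GPU driver are visible.
if torch.xpu.is_available():
    print("XPU devices:", torch.xpu.device_count())
    print("Device 0:", torch.xpu.get_device_name(0))
else:
    print("No XPU device detected; re-check the driver install and Resizable BAR.")
```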
@@ -63,18 +63,23 @@ def local_chat(
     prompt_file : str | None = None,
     mode: str = "normal",
     force_think: bool = False,
-    chunk_size: int = 8192
+    chunk_size: int = 8192,
+    device: str = "cuda"
 ):

     torch.set_grad_enabled(False)

     Config().cpu_infer = cpu_infer
+    if torch.xpu.is_available():
+        use_cuda_graph = False

     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
     if mode == 'long_context':
         assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
         torch.set_default_dtype(torch.float16)
+    elif torch.xpu.is_available() and config.architectures[0] == "DeepseekV3ForCausalLM":
+        torch.set_default_dtype(torch.float16)
     else:
         torch.set_default_dtype(config.torch_dtype)

@@ -109,7 +114,7 @@ def local_chat(
         gguf_path = input(
             "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
         )
-    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
+    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, default_device=device)

     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)

@@ -172,12 +177,12 @@ def local_chat(

     if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
         generated = prefill_and_generate(
-            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
+            model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
             use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
         )
     else:
         generated = prefill_and_generate(
-            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
+            model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
         )
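Since `local_chat()` now takes a `device` argument and forwards it to `optimize_and_load_gguf` and `prefill_and_generate`, the XPU path can also be driven from Python rather than the CLI. A hedged sketch (the import path and the GGUF path are assumptions, not taken from this diff):

```python
import torch
from ktransformers.local_chat import local_chat  # import path assumed

# Mirror the script's own behaviour: use XPU when present, otherwise the CUDA default.
device = "xpu" if torch.xpu.is_available() else "cuda"

local_chat(
    model_path="deepseek-ai/DeepSeek-R1",
    gguf_path="/path/to/DeepSeek-R1-GGUF",  # placeholder
    optimize_config_path="ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml",
    cpu_infer=33,            # cpu_cores + 1, as in the tutorial above
    max_new_tokens=200,
    device=device,
)
```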
@@ -293,7 +293,7 @@ class KGQACache(nn.Module):
         self.v_caches = []


-    def load(self, inference_context: sched_ext.InferenceContext):
+    def load(self, inference_context: "sched_ext.InferenceContext"):
         print(self.config.num_hidden_layers)
         for i in range(self.config.num_hidden_layers):
             self.k_caches.append(
@@ -107,6 +107,7 @@ class DeepseekV2RMSNorm(nn.Module):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.hidden_size = hidden_size

     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
@@ -587,6 +587,100 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):

         return attn_output, None, past_key_value

+    def forward_xpu(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+        query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        kv = (
+            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .transpose(1, 2)
+        )
+
+        k_nope, value_states = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+        kv_seq_len = value_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        position_embeddings = kwargs.get("position_embeddings", None)
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            key_states = torch.cat(
+                [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
+                dim=-1
+            )
+            from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+            rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim :],
+                                           key_states[:, :, :, self.qk_nope_head_dim:],
+                                           cos, sin, True)
+        else:
+            q_nope, q_pe = torch.split(
+                query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+            cos, sin = self.rotary_emb(q_pe, position_ids)
+            q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
+            query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+
+            key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+            key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+            key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(
+                key_states.half(), value_states.half(), self.layer_idx, cache_kwargs
+            )
+
+        attn_weights = None
+        from ipex_llm.transformers.models.common import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states.half(), key_states, value_states,
+            attention_mask.half(), q_len == kv_seq_len, self.softmax_scale
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+        attn_output = self.o_proj(attn_output).to(hidden_states.dtype)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
     def forward(
         self,
         hidden_states: torch.Tensor,

@@ -598,7 +692,18 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if (os.name == 'nt'
+        if torch.xpu.is_available():
+            return self.forward_xpu(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+        elif (os.name == 'nt'
             or get_compute_capability() < 8
             or hidden_states.device.type == 'cpu'
             or device_manager.gpu_vendor != GPUVendor.NVIDIA):
@@ -51,7 +51,10 @@ def generate_cuda_graphs(chunk_size: int) -> list:

     return deduplicate_and_sort(base_list + multiples)
 #cuda_graphs = [Config().chunk_size]
-cuda_graphs = generate_cuda_graphs(Config().chunk_size)
+if torch.cuda.is_available():
+    cuda_graphs = generate_cuda_graphs(Config().chunk_size)
+else:
+    cuda_graphs = 1
 # class Base(BaseInjectedModule, ABC):
 class KExpertsBase(ABC):
     def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):

@@ -177,6 +180,11 @@ class KExpertsCPU(KExpertsBase):
         n_routed_experts = self.n_routed_experts
         self.cpu_infer = KExpertsCPU.CPU_INFER
         # n_routed_experts = len(self.orig_module)
+        model_dtype = torch.get_default_dtype()
+        if torch.xpu.is_available() and model_dtype == torch.float16:
+            hidden_type = 1 # fp16
+        else:
+            hidden_type = 30 # bf16
         if self.backend == "llamafile":
             moe_config = MOEConfig(
                 n_routed_experts,

@@ -192,7 +200,7 @@ class KExpertsCPU(KExpertsBase):
                 self.gate_type,
                 self.up_type,
                 self.down_type,
-                30, # TODO: get from model.dtype
+                hidden_type, # TODO: get from model.dtype
             )
             self.moe = MOE(moe_config)
         elif self.backend == "AMXBF16":

@@ -252,6 +260,10 @@ class KExpertsCPU(KExpertsBase):
             KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
             KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
             KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
+            if torch.xpu.is_available():
+                KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=model_dtype)
+                KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True)
+            else:
                 KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
                 KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)

@@ -285,9 +297,9 @@ class KExpertsCPU(KExpertsBase):
     def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
         # generate, capture and run cuda graph
         # print(expert_ids)
-        if bsz_tensor is None:
+        if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1):
             bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
-        if torch.cuda.is_current_stream_capturing():
+        if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             if cuda_graph_idx != -1:
                 KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
                 KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)

@@ -307,6 +319,15 @@ class KExpertsCPU(KExpertsBase):
             self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
             KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
             return KExpertsCPU.output_gpu_map[self.out_device]
+        elif input_tensor.size(0)==1 and torch.xpu.is_available():
+            KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True)
+            KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True)
+            KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True)
+            # KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True)
+            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
+            self.cpu_infer.sync()
+            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
+            return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1)
         else:
             input_tensor = input_tensor.contiguous().cpu()
             expert_ids = expert_ids.contiguous().cpu()

@@ -822,7 +843,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
         topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

-        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity).squeeze(0)

@@ -922,7 +943,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

         # only for generate phase
-        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity).squeeze(0)

@@ -1122,7 +1143,7 @@ class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):


         # only for generate phase
-        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
             self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)

@@ -1304,7 +1325,7 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         routing_weights = routing_weights.to(hidden_states.dtype)

         # only for generate phase
-        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
             self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
             y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
             y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_

@@ -1417,7 +1438,7 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
         routing_weights = routing_weights.to(hidden_states.dtype)

         # only for generate phase
-        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
             self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
             # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
             # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
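The single-token XPU branch added above reuses the staging idea the CUDA-graph path already relies on: pinned host buffers filled with non-blocking copies, a CPU expert kernel run against those buffers, then an explicit `sync()` because there is no GPU stream callback to order against. A generic sketch of that pattern, with illustrative shapes and a placeholder CPU kernel:

```python
import torch

hidden_size = 7168  # illustrative size

# Page-locked (pinned) host buffers make host<->device copies asynchronous.
input_cpu = torch.zeros(hidden_size, device="cpu", dtype=torch.float16, pin_memory=True)
output_cpu = torch.zeros(hidden_size, device="cpu", dtype=torch.float16, pin_memory=True)

def run_one_token(hidden_dev: torch.Tensor, cpu_kernel) -> torch.Tensor:
    # Stage the single-token activation on the host without blocking the caller.
    input_cpu.copy_(hidden_dev.view(-1), non_blocking=True)
    # Make sure the staging copy has landed before the CPU reads the buffer.
    if hidden_dev.device.type == "xpu":
        torch.xpu.synchronize()
    elif hidden_dev.device.type == "cuda":
        torch.cuda.synchronize()
    # Run the CPU expert kernel (placeholder callable), the moral equivalent of
    # cpu_infer.submit(...) followed by cpu_infer.sync() in this diff.
    cpu_kernel(input_cpu, output_cpu)
    # Copy the result back to the device the activation came from.
    return output_cpu.to(hidden_dev.device, non_blocking=True).view(1, -1)
```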
@@ -183,3 +183,33 @@ class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
             self.weight = None
         if self.e_score_correction_bias is not None:
             self.e_score_correction_bias = None
+
+
+class KMoEGateIPEXLLM(KMoEGate):
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        generate_device: str = "xpu",
+        prefill_device: str = "xpu",
+        **kwargs,
+    ):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
+        KMoEGate.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        self.generate_device = generate_device
+        self.prefill_device = prefill_device
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        x = hidden_states.view(-1, hidden_states.size(-1))
+        logits = torch.nn.functional.linear(
+            x.type(torch.float32), self.orig_module.weight.type(torch.float32), None
+        )
+        scores = logits.sigmoid()
+
+        from ipex_llm.transformers.models.common import moe_group_topk
+        topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias,
+                                               self.n_group, self.topk_group, self.top_k,
+                                               self.norm_topk_prob, self.routed_scaling_factor)
+        return topk_idx, topk_weight.to(x.dtype)
@@ -30,6 +30,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
 from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_loader import GGUFLoader
+if not torch.xpu.is_available():
     from flashinfer.norm import (
         fused_add_rmsnorm,
         rmsnorm,

@@ -194,3 +195,28 @@ class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
         if residual is not None:
             return self.weight * x.to(input_dtype), residual
         return self.weight * x.to(input_dtype)
+
+
+class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
+    def __init__(self,
+                 key: str,
+                 gguf_loader : GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 prefill_device: str = "xpu",
+                 generate_device: str = "xpu",
+                 **kwargs):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+        self.orig_module.__init__(orig_module.hidden_size,
+                                  orig_module.variance_epsilon)
+        self.eps = orig_module.variance_epsilon
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        from ipex_llm.transformers.models.common import rms_norm_forward
+        output = rms_norm_forward(self, x.float())
+        return output.to(x.dtype)
+
+    def load(self):
+        BaseInjectedModule.load(self)
+        if self.weight.dtype != torch.float32:
+            self.weight = self.weight.float()
@@ -14,10 +14,12 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 import ctypes
 import torch
 from torch import Tensor, nn
+if not torch.xpu.is_available():
     import KTransformersOps
     import vLLMMarlin
 from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
 from ktransformers.util.utils import InferenceState
+if not torch.xpu.is_available():
     from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
         MarlinWorkspace,
         marlin_quantize,

@@ -778,6 +780,75 @@ class KLinearCPUInfer(KLinearBase):
         if self.has_bias:
             self.bias = None

+class KLinearIPEXLLM(KLinearBase):
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        device: str = "xpu",
+        precision: str = "sym_int4",
+        **kwargs,
+    ):
+        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        self.has_bias = False
+        self.dtype = torch.get_default_dtype()
+        self.weight = None
+        self.has_bias = False
+        self.precision = precision
+        self.qtype = None
+
+    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
+        dtype = x.dtype
+        out_device = x.device
+        from ipex_llm.transformers.models.common import linear_forward
+        x = linear_forward(x.half(), self.weight, self.qtype, self.out_features)
+
+        if self.has_bias:
+            x = x + self.bias
+        x = x.to(dtype=dtype, device=out_device)
+        return x
+
+    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if self.loaded: return
+        if device is None: device = self.device
+        assert device.lower()[:3] == "xpu", "IPEX-LLM quantized linear only supports XPU device"
+        if w is None: w = self.load_weight(device=device)
+
+        if isinstance(w, nn.Parameter):
+            try:
+                weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            except:
+                weight = w.to(dtype=self.dtype).T
+            self.has_bias = False
+        elif isinstance(w, tuple):
+            try:
+                weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            except:
+                weight = w[0].to(dtype=self.dtype).T
+            self.bias = w[1].to(dtype=self.dtype)
+            self.has_bias = True
+        else:
+            raise ValueError("Invalid weight type")
+        weight = weight.to("cpu").float().transpose(0, 1).contiguous()
+
+        if self.has_bias:
+            self.bias = self.bias.to(device)
+
+        # quantize linear weight
+        from ipex_llm.transformers.models.common import quantize_linear
+        paramsLowBit, qtype = quantize_linear(weight, self.in_features, self.precision)
+        self.weight = paramsLowBit.to(device)
+        self.qtype = qtype
+        self.loaded = True
+
+    def unload(self):
+        if self.weight is not None:
+            self.weight = None
+        if self.has_bias:
+            self.bias = None
+
 LINEAR_MAP = {
     "KLinearMarlin": KLinearMarlin,
     "KLinearTorch": KLinearTorch,

@@ -785,6 +856,7 @@ LINEAR_MAP = {
     "VLinearMarlin": VLinearMarlin,
     "KLinearFP8": KLinearFP8,
     "KLinearQ8": KLinearQ8,
+    "KLinearIPEXLLM": KLinearIPEXLLM,
 }

 class KTransformersLinear(BaseInjectedModule, KLinearBase):
@@ -647,6 +647,13 @@ class KDeepseekV2Model(BaseInjectedModule):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

+        if inputs_embeds.device.type == "xpu" and position_ids is not None:
+            cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,
+                                                           position_ids)
+            position_embeddings = (cos, sin)
+        else:
+            position_embeddings = None
+
         if per_layer_prefill_flag:
             causal_mask = None
         else:

@@ -737,6 +744,7 @@ class KDeepseekV2Model(BaseInjectedModule):
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )
                 t5 = time.time()
                 if per_layer_prefill_flag:
@@ -103,7 +103,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
     for name, child in module._modules.items():
         if child is not None:
             child_prefix = prefix + name + "."
-            gen_optimize_config(child, out_data, rule_list, child_prefix)
+            gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device)


 def translate_model_config(model_config: PretrainedConfig):

@@ -127,8 +127,11 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
     with torch.device("meta"):
         inject(module, optimize_config, model_config, weights_loader)
     # pre load lm_head because its big inter result
-    load_weights(module.lm_head, weights_loader, "lm_head.")
-    load_weights(module, weights_loader)
+    load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
+    load_weights(module, weights_loader, device=default_device)
     module.gguf_loader = weights_loader
     del_meta(module)
+    if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    elif torch.xpu.is_available():
+        torch.xpu.empty_cache()
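With `default_device` now threaded through `gen_optimize_config` and `load_weights`, callers that apply a rule file directly need to pass the target device themselves. A hedged sketch of such a call (the import path, the generic model construction, and the GGUF path are illustrative assumptions, not taken from this diff):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from ktransformers.optimize.optimize import optimize_and_load_gguf  # module path assumed

config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-R1", trust_remote_code=True)
with torch.device("meta"):
    # Build the module skeleton on the meta device; weights come from the GGUF file.
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

# default_device is forwarded to load_weights(), mirroring the --device xpu flag
# added to local_chat.py earlier in this change.
optimize_and_load_gguf(
    model,
    "ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml",
    "/path/to/DeepSeek-R1-GGUF",  # placeholder
    config,
    default_device="xpu",
)
```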
@@ -0,0 +1,64 @@
- match:
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^model\\.layers\\..*"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
      generate_op: "KLinearIPEXLLM"
      prefill_op: "KLinearIPEXLLM"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "xpu"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "xpu"
  recursive: False  # don't recursively inject submodules of this module
- match:
    class: ktransformers.models.modeling_deepseek.DeepseekV2RMSNorm
  replace:
    class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
      device: "xpu"
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
@@ -0,0 +1,81 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
      generate_op: "KLinearIPEXLLM"
      prefill_op: "KLinearIPEXLLM"
- match:
    name: "^model\\.layers\\..*"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
      generate_op: "KLinearIPEXLLM"
      prefill_op: "KLinearIPEXLLM"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
  replace:
    class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGateIPEXLLM
    kwargs:
      generate_device: "xpu:0"
      prefill_device: "xpu:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "xpu"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "xpu"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
      absorb_for_prefill: False  # change this to True to enable long context (prefill may be slower)
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
@@ -24,6 +24,7 @@ from typing import Sequence
 import os
 from enum import IntEnum
 import torch
+if not torch.xpu.is_available():
     import KTransformersOps
 import ctypes
 import math
@@ -7,6 +7,7 @@ from typing import Sequence
 import os
 from enum import IntEnum
 import torch
+if not torch.xpu.is_available():
     import KTransformersOps
 from safetensors import safe_open
 from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant

@@ -459,7 +460,7 @@ class GGUFLoader(ModelLoader):
             values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values)
+            values = torch.from_numpy(values).to(device)

         if ggml_name == "BF16":
             values = values.view(torch.bfloat16)
@ -27,6 +27,7 @@ from ktransformers.operators import base_operator
|
||||||
from ktransformers.models.custom_cache import StaticCache
|
from ktransformers.models.custom_cache import StaticCache
|
||||||
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
||||||
from ktransformers.util.textstream import TextStreamer
|
from ktransformers.util.textstream import TextStreamer
|
||||||
|
if not torch.xpu.is_available():
|
||||||
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
|
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
|
@ -59,6 +60,8 @@ def get_compute_capability(device:torch.device = None):
|
||||||
return min_compute_capability_major
|
return min_compute_capability_major
|
||||||
else:
|
else:
|
||||||
return torch.cuda.get_device_properties(device)
|
return torch.cuda.get_device_properties(device)
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
def set_module(model, submodule_key, module):
|
def set_module(model, submodule_key, module):
|
||||||
tokens = submodule_key.split('.')
|
tokens = submodule_key.split('.')
|
||||||
|
@ -97,7 +100,7 @@ def get_all_used_cuda_device(device_map:dict):
|
||||||
all_device_list = list(all_device_list)
|
all_device_list = list(all_device_list)
|
||||||
return all_device_list
|
return all_device_list
|
||||||
|
|
||||||
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = ""):
|
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
|
||||||
prefix = prefix.replace("orig_module.", "")
|
prefix = prefix.replace("orig_module.", "")
|
||||||
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
|
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
|
||||||
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
|
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
|
||||||
|
@ -118,7 +121,10 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.xpu.is_available():
+                torch.xpu.empty_cache()
             weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
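The hunk above frees the allocator cache on whichever backend is present before each tensor is materialized. A backend-agnostic helper along the same lines might look like the sketch below (the `hasattr` guard covers PyTorch builds that ship without a `torch.xpu` module):

```python
import torch

def empty_device_cache() -> None:
    # Release cached allocator blocks so the next weight load has headroom;
    # a no-op on CPU-only builds.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()

empty_device_cache()  # safe to call on any build
```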
@ -126,12 +132,24 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
             #print(load_config.tensor_file_map.keys())
             raise Exception(f"can't find {translated_key} in GGUF file!")

-def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix=''):
+def sync_all_device(all_device_list):
+    for device in all_device_list:
+        if "cuda" in device.lower():
+            torch.cuda.synchronize(device)
+        elif "xpu" in device.lower():
+            torch.xpu.synchronize(device)
+        else:
+            raise RuntimeError("The device {} is not available".format(device))
+
+torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}
+
+def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
     #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
-        load_cur_state_dict(module, gguf_loader, prefix)
+        load_cur_state_dict(module, gguf_loader, prefix, device=device)
         for name, child in module._modules.items():
-            load_weights(child, gguf_loader, prefix+name+".")
+            load_weights(child, gguf_loader, prefix+name+".", device=device)
     else:
         module.load()

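A short, self-contained sketch of how the two additions above are meant to be used: `torch_device_mapping` normalizes a bare backend name to a concrete device index, and `sync_all_device` replaces the CUDA-only synchronize loop further down in generation.

```python
import torch

# Mirrors the additions in the hunk above.
torch_device_mapping = {"cuda": "cuda:0", "xpu": "xpu:0"}

def sync_all_device(all_device_list):
    # Block until every device holding part of the model has finished
    # its queued work, regardless of backend.
    for device in all_device_list:
        if "cuda" in device.lower():
            torch.cuda.synchronize(device)
        elif "xpu" in device.lower():
            torch.xpu.synchronize(device)
        else:
            raise RuntimeError("The device {} is not available".format(device))

# torch_device_mapping.get("xpu", "xpu")  -> "xpu:0"
# sync_all_device(["cuda:0", "cuda:1"])   # or ["xpu:0"] on an Intel GPU box
```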
@ -194,8 +212,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     torch._dynamo.config.suppress_errors = True
     batch_size, seq_length = inputs.shape
     device_map = model.gguf_loader.tensor_device_map
-    torch_device = get_device('blk.0.self_attn', device_map)
-    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+    torch_device = get_device('model.layers.0.self_attn', device_map)
+    torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
     inputs = inputs.to(torch_device)
     all_cuda_device = get_all_used_cuda_device(device_map)

@ -208,7 +226,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             logits = cuda_graph_runner(cur_token, position_ids, cache_position)
         else:
             # custom_stream = torch.cuda.Stream()
-            torch.cuda.set_device(torch_device)
+            if torch.cuda.is_available():
+                torch.cuda.set_device(torch_device)
+            elif torch.xpu.is_available():
+                torch.xpu.set_device(torch_device)
+            else:
+                RuntimeError("The device: {torch_device} is not available")
             inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
             # with torch.cuda.stream(custom_stream):
             logits=model(inputs_embeds=inputs_embeds,
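Note that the new fallback branch above only constructs a `RuntimeError` without raising it, and the message is not an f-string. A hedged sketch of the same device selection with those two details tightened:

```python
import torch

def set_active_device(torch_device: str) -> None:
    # Route subsequent kernel launches to the chosen accelerator.
    if torch.cuda.is_available():
        torch.cuda.set_device(torch_device)
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.set_device(torch_device)
    else:
        # Unlike the hunk above, actually raise and interpolate the device name.
        raise RuntimeError(f"The device: {torch_device} is not available")

# set_active_device("xpu:0")  # on an Intel GPU build of PyTorch
```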
@ -216,10 +239,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                          cache_position=cache_position,
                          past_key_values=past_key_values,
                          return_dict=False, use_cache=True)[0]
-            if past_key_values != None:
+            if past_key_values != None and isinstance(past_key_values, StaticCache):
                 past_key_values.change_seq_length(1)
-        for device in all_cuda_device:
-            torch.cuda.synchronize(device)
+        sync_all_device(all_cuda_device)
         #print(logits)
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
@ -245,11 +267,19 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud

         return logits

-    torch.cuda.set_device(torch_device)
+    if torch.cuda.is_available():
+        torch.cuda.set_device(torch_device)
+    elif torch.xpu.is_available():
+        torch.xpu.set_device(torch_device)
+    else:
+        RuntimeError("The device: {torch_device} is not available")
     with torch.no_grad():

         stream = TextStreamer(tokenizer)
-        if mode != 'long_context':
+        if torch.xpu.is_available():
+            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
+            past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+        elif mode != 'long_context':
             past_key_values = StaticCache(
                 config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
             )
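For readability, the cache selection above reduces to the sketch below (assuming `StaticCache` and the `ipex_llm` package referenced in the hunk; the long-context branch that is not shown here is left untouched):

```python
import torch

def make_kv_cache(model, seq_length, max_new_tokens, device_map, mode):
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        # XPU path: FP8 KV cache from ipex_llm, as in the hunk above.
        from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
        return DynamicUnbalancedFp8Cache.from_legacy_cache(None)
    if mode != "long_context":
        # CUDA/ROCm/MUSA path keeps the existing StaticCache.
        from ktransformers.models.custom_cache import StaticCache
        return StaticCache(
            config=model.config,
            max_batch_size=1,
            max_cache_len=seq_length + max_new_tokens,
            device=device_map,
            dtype=model.dtype,
        )
    return None  # long-context mode sets up its cache elsewhere (not shown)
```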
12 setup.py
@ -39,6 +39,7 @@ try:
     from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
 except ImportError:
     MUSA_HOME=None
+KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()

 with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"

@ -225,6 +226,8 @@ class VersionInfo:
             backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
         elif ROCM_HOME is not None:
             backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
+        elif torch.xpu.is_available():
+            backend_version = f"xpu"
         else:
             raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
         package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
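For illustration only, with hypothetical values for everything except the new `xpu` backend tag, the wheel version string from the line above composes as:

```python
# Hypothetical inputs; only backend_version = "xpu" comes from the change above.
flash_version = "0.2.4"
backend_version = "xpu"
torch_version = "2.5"
cpu_instruct = "avx2"

package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
print(package_version)  # 0.2.4+xputorch2.5avx2
```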
@ -495,6 +498,8 @@ class CMakeBuild(BuildExtension):
             cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
         elif ROCM_HOME is not None:
             cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
+        elif KTRANSFORMERS_BUILD_XPU:
+            cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
         else:
             raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.")

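Condensed, the backend-to-CMake-flag selection above behaves like the sketch below (the `musa_home`/`rocm_home` parameters are hypothetical stand-ins for the module-level `MUSA_HOME`/`ROCM_HOME` used in setup.py, and the CUDA branch is omitted):

```python
import torch

def backend_cmake_args(musa_home=None, rocm_home=None):
    if musa_home is not None:
        return ["-DKTRANSFORMERS_USE_MUSA=ON"]
    if rocm_home is not None:
        return ["-DKTRANSFORMERS_USE_ROCM=ON"]
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        # XPU builds switch CUDA off explicitly so the C++ extension
        # configures without a CUDA toolkit installed.
        return ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
    raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.")

# e.g. backend_cmake_args() on an Intel GPU box ->
#   ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
```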
@ -620,9 +625,12 @@ elif MUSA_HOME is not None:
         ]
     }
 )
+elif torch.xpu.is_available(): #XPUExtension is not available now.
+    ops_module = None
 else:
     raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

-ext_modules = [
-    CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
-    ops_module,
+if not torch.xpu.is_available():
+    ext_modules = [
+        CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+        ops_module,
@ -643,6 +651,10 @@ if with_balance:
     ext_modules.append(
         CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
     )
+else:
+    ext_modules = [
+        CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+    ]

 setup(
     name=VersionInfo.PACKAGE_NAME,