Merge pull request #1295 from rnwang04/xpu_support

Enable ktransformers on Intel GPU with local chat backend
aubreyli 2025-05-14 20:58:35 +08:00 committed by GitHub
commit f7ee993fdc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 673 additions and 81 deletions


@ -23,6 +23,8 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
<h2 id="Updates">🔥 Updates</h2>
* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).
* **Apr 29, 2025**: Support AMX-Int8, AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))
https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2


@ -41,6 +41,7 @@ option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" ON)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
@ -303,6 +304,8 @@ elseif (UNIX)
message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif()
elseif (KTRANSFORMERS_USE_XPU)
add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
else()
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
@ -361,6 +364,7 @@ elseif(UNIX)
message(STATUS "Building for HIP")
elseif(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
elseif(KTRANSFORMERS_USE_XPU)
else()
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()


@ -17,6 +17,7 @@
#include <queue>
#include <thread>
#include <vector>
#include <stdexcept>
#ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
@ -66,10 +67,14 @@
}
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)
void (*func)(void*) = (void (*)(void*))params.first;
void* args = (void*)params.second;
*((CPUInfer**)args) = this;
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
#else
throw std::runtime_error("submit_with_cuda_stream is not supported on this platforma");
#endif
}
static void sync_(void* cpu_infer_ptr) {
@ -78,7 +83,11 @@
}
void sync_with_cuda_stream(intptr_t user_cuda_stream) {
#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
#else
throw std::runtime_error("sync_with_cuda_stream is not supported on this platforma");
#endif
}
public:


@ -9,7 +9,7 @@
**/
// Python bindings
#include "cpu_backend/cpuinfer.h"
#ifndef KTRANSFORMERS_USE_ROCM
#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)
#include "device_launch_parameters.h"
#endif
#include "llamafile/flags.h"


@ -21,7 +21,7 @@ interface, RESTful APIs compliant with OpenAI and Ollama, and even a simplified
Our vision for KTransformers is to serve as a flexible platform for experimenting with innovative LLM inference optimizations. Please let us know if you need any other features.
<h2 id="Updates">🔥 Updates</h2>
* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./en/xpu.md)).
* **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./en/llama4.md)).
* **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./en/balance-serve.md)).
* **Mar 27, 2025**: Support Multi-concurrency.

doc/en/xpu.md (new file, 117 lines)

@ -0,0 +1,117 @@
# Intel GPU Support for KTransformers (Beta)
## Introduction
### Overview
We are excited to introduce **Intel GPU support** in KTransformers (Beta release). This implementation has been developed and tested on Intel Xeon Scalable processors and Intel Arc GPUs (such as the A770 and B580).
## Installation Guide
### 1. Install Intel GPU Driver
Begin by installing the GPU drivers for your Intel GPU:
- [Official GPU Installation Guide for Intel GPUs](https://dgpu-docs.intel.com/driver/overview.html)
> [!Important]
> Ensure that **Resizable BAR** is enabled in your system's BIOS before proceeding. This is essential for optimal GPU performance and to avoid potential issues such as `Bus error (core dumped)`. For detailed steps, please refer to the official guidance [here](https://www.intel.com/content/www/us/en/support/articles/000090831/graphics.html).
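As an optional sanity check (our suggestion, assuming a Linux host with `pciutils` installed), the GPU's `lspci` entry should list a Resizable BAR capability once the feature is enabled:
```bash
# The Intel GPU entry should show a "Resizable BAR" capability with a large
# current BAR size (e.g. 16G on an Arc A770) when the feature is active.
sudo lspci -vvv | grep -i -A3 "resizable bar"
```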
### 2. Set Up Conda Environment
We recommend using Miniconda3/Anaconda3 for environment management:
```bash
# Download Miniconda
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
# Create environment
conda create --name ktransformers python=3.11
conda activate ktransformers
# Install required libraries
conda install -c conda-forge libstdcxx-ng
# Verify GLIBCXX version (should include 3.4.32)
strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
```
> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`.
### 3. Install PyTorch and IPEX-LLM
Install PyTorch with XPU backend support and [IPEX-LLM](https://github.com/intel/ipex-llm):
```bash
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu
pip uninstall torch torchvision torchaudio
pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu # install torch2.7
pip install packaging ninja cpufeature numpy
pip uninstall intel-opencl-rt dpcpp-cpp-rt
```
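A minimal verification (not part of the upstream instructions) that the XPU build of PyTorch is the one Python picks up:
```bash
# Should print the installed torch version and True if the Arc GPU is visible
python -c "import torch; print(torch.__version__, torch.xpu.is_available())"
```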
### 4. Build ktransformers
```bash
# Clone repository
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule update --init
# Install dependencies
bash install.sh
pip uninstall triton pytorch-triton-xpu
pip install pytorch-triton-xpu==3.3.0 --extra-index-url https://download.pytorch.org/whl/xpu # to avoid potential triton import error
```
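As a quick smoke test (our suggestion, assuming the build finished without errors), confirm that the CPU inference extension built above imports and that an XPU device is enumerated:
```bash
# cpuinfer_ext is the CPU-side extension built by install.sh; expect a device count of at least 1
python -c "import cpuinfer_ext, torch; print(torch.xpu.device_count())"
```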
## Running DeepSeek-R1 Models
### Configuration for 16GB VRAM GPUs
Use our optimized configuration for VRAM-constrained GPUs:
```bash
export SYCL_CACHE_PERSISTENT=1
export ONEAPI_DEVICE_SELECTOR=level_zero:0
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
python ktransformers/local_chat.py \
--model_path deepseek-ai/DeepSeek-R1 \
--gguf_path <path_to_gguf_files> \
--optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \
--cpu_infer <cpu_cores + 1> \
--device xpu \
--max_new_tokens 200
```
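The `<cpu_cores + 1>` placeholder is left for you to fill in; one way to read off the physical core count it refers to (our interpretation, not an official rule) is:
```bash
# Physical cores = Core(s) per socket x Socket(s); pass that total plus one to --cpu_infer
lscpu | grep -E "Socket\(s\)|Core\(s\) per socket"
```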
## Known Limitations
- The serving backend is not yet supported on the Intel GPU platform; only the local chat backend is available for now
## Troubleshooting
1. Best Known Configuration (BKC) for best performance
To obtain the best performance on the Intel GPU platform, we recommend locking the GPU frequency and setting the CPU to performance mode with the settings below.
```bash
echo "performance" | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
echo 0 | sudo tee /sys/devices/system/cpu/cpu*/power/energy_perf_bias
# 2400 MHz is the max frequency for Arc A770
sudo xpu-smi config -d 0 -t 0 --frequencyrange 2400,2400
# 2850 MHz is the max frequency for Arc B580
# sudo xpu-smi config -d 0 -t 0 --frequencyrange 2850,2850
```
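If you are unsure which device index to pass to `-d`, `xpu-smi` can enumerate the installed GPUs (this assumes Intel's `xpu-smi` tool is installed, as used above):
```bash
# Lists detected Intel GPUs with the device IDs expected by `xpu-smi config -d <id>`
xpu-smi discovery
```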
2. Runtime error like `xpu/sycl/TensorCompareKernels.cpp:163: xxx. Aborted (core dumped)`
This error is usually related to the GPU driver. If you encounter it, update `intel-level-zero-gpu` to `1.3.29735.27-914~22.04` (a version we have verified) with the commands below.
```bash
wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \
sudo gpg --dearmor --output /usr/share/keyrings/intel-graphics.gpg
echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu jammy client" | \
sudo tee /etc/apt/sources.list.d/intel-gpu-jammy.list
sudo apt update
# or sudo apt update --allow-insecure-repositories
sudo apt install intel-level-zero-gpu=1.3.29735.27-914~22.04
```
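To confirm the pinned driver version is what actually got installed (a generic Debian/Ubuntu check, not specific to KTransformers):
```bash
dpkg -l | grep intel-level-zero-gpu
```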
3. `ImportError: cannot import name 'intel' from 'triton._C.libtriton'`
Installing the stock `triton` package causes `pytorch-triton-xpu` to stop working. You can resolve the issue with the following commands:
```bash
pip uninstall triton pytorch-triton-xpu
# Reinstall correct version of pytorch-triton-xpu
pip install pytorch-triton-xpu==3.3.0 --index-url https://download.pytorch.org/whl/xpu
```
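After reinstalling, a quick check (our suggestion) that the XPU Triton build is the one being imported:
```bash
pip show pytorch-triton-xpu | head -n 2
python -c "import triton; print(triton.__version__)"
```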


@ -63,18 +63,23 @@ def local_chat(
prompt_file : str | None = None,
mode: str = "normal",
force_think: bool = False,
chunk_size: int = 8192
chunk_size: int = 8192,
device: str = "cuda"
):
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
if torch.xpu.is_available():
use_cuda_graph = False
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if mode == 'long_context':
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
elif torch.xpu.is_available() and config.architectures[0] == "DeepseekV3ForCausalLM":
torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
@ -109,7 +114,7 @@ def local_chat(
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, default_device=device)
try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
@ -172,12 +177,12 @@ def local_chat(
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
)
else:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
)


@ -213,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module):
self.v_caches = []
def load(self, inference_context: "sched_ext.InferenceContext"):
def load(self, inference_context: "sched_ext.InferenceContext"):
for i in range(self.config.num_hidden_layers):
self.k_caches.append(
@ -293,7 +293,7 @@ class KGQACache(nn.Module):
self.v_caches = []
def load(self, inference_context: sched_ext.InferenceContext):
def load(self, inference_context: "sched_ext.InferenceContext"):
print(self.config.num_hidden_layers)
for i in range(self.config.num_hidden_layers):
self.k_caches.append(


@ -107,6 +107,7 @@ class DeepseekV2RMSNorm(nn.Module):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.hidden_size = hidden_size
def forward(self, hidden_states):
input_dtype = hidden_states.dtype


@ -587,6 +587,100 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
return attn_output, None, past_key_value
def forward_xpu(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
k_nope, value_states = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
position_embeddings = kwargs.get("position_embeddings", None)
if position_embeddings is not None:
cos, sin = position_embeddings
key_states = torch.cat(
[k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
dim=-1
)
from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim :],
key_states[:, :, :, self.qk_nope_head_dim:],
cos, sin, True)
else:
q_nope, q_pe = torch.split(
query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states.half(), value_states.half(), self.layer_idx, cache_kwargs
)
attn_weights = None
from ipex_llm.transformers.models.common import scaled_dot_product_attention
attn_output = scaled_dot_product_attention(
query_states.half(), key_states, value_states,
attention_mask.half(), q_len == kv_seq_len, self.softmax_scale
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output).to(hidden_states.dtype)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def forward(
self,
hidden_states: torch.Tensor,
@ -598,10 +692,21 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if (os.name == 'nt'
or get_compute_capability() < 8
or hidden_states.device.type == 'cpu'
or device_manager.gpu_vendor != GPUVendor.NVIDIA):
if torch.xpu.is_available():
return self.forward_xpu(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
**kwargs,
)
elif (os.name == 'nt'
or get_compute_capability() < 8
or hidden_states.device.type == 'cpu'
or device_manager.gpu_vendor != GPUVendor.NVIDIA):
return self.forward_windows(
hidden_states,
attention_mask,


@ -51,7 +51,10 @@ def generate_cuda_graphs(chunk_size: int) -> list:
return deduplicate_and_sort(base_list + multiples)
#cuda_graphs = [Config().chunk_size]
cuda_graphs = generate_cuda_graphs(Config().chunk_size)
if torch.cuda.is_available():
cuda_graphs = generate_cuda_graphs(Config().chunk_size)
else:
cuda_graphs = 1
# class Base(BaseInjectedModule, ABC):
class KExpertsBase(ABC):
def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
@ -177,6 +180,11 @@ class KExpertsCPU(KExpertsBase):
n_routed_experts = self.n_routed_experts
self.cpu_infer = KExpertsCPU.CPU_INFER
# n_routed_experts = len(self.orig_module)
model_dtype = torch.get_default_dtype()
if torch.xpu.is_available() and model_dtype == torch.float16:
hidden_type = 1 # fp16
else:
hidden_type = 30 # bf16
if self.backend == "llamafile":
moe_config = MOEConfig(
n_routed_experts,
@ -192,7 +200,7 @@ class KExpertsCPU(KExpertsBase):
self.gate_type,
self.up_type,
self.down_type,
30, # TODO: get from model.dtype
hidden_type, # TODO: get from model.dtype
)
self.moe = MOE(moe_config)
elif self.backend == "AMXBF16":
@ -252,8 +260,12 @@ class KExpertsCPU(KExpertsBase):
KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
if torch.xpu.is_available():
KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=model_dtype)
KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True)
else:
KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
if bsz_tensor is None:
@ -285,9 +297,9 @@ class KExpertsCPU(KExpertsBase):
def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
# generate, capture and run cuda graph
# print(expert_ids)
if bsz_tensor is None:
if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1):
bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
if torch.cuda.is_current_stream_capturing():
if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
if cuda_graph_idx != -1:
KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
@ -307,6 +319,15 @@ class KExpertsCPU(KExpertsBase):
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device]
elif input_tensor.size(0)==1 and torch.xpu.is_available():
KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True)
KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True)
KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True)
# KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True)
self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
self.cpu_infer.sync()
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1)
else:
input_tensor = input_tensor.contiguous().cpu()
expert_ids = expert_ids.contiguous().cpu()
@ -822,7 +843,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
@ -922,7 +943,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
# only for generate phase
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
@ -1122,7 +1143,7 @@ class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
@ -1304,7 +1325,7 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
routing_weights = routing_weights.to(hidden_states.dtype)
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
@ -1417,7 +1438,7 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
routing_weights = routing_weights.to(hidden_states.dtype)
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_


@ -182,4 +182,34 @@ class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
if self.weight is not None:
self.weight = None
if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None
self.e_score_correction_bias = None
class KMoEGateIPEXLLM(KMoEGate):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
generate_device: str = "xpu",
prefill_device: str = "xpu",
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
KMoEGate.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.generate_device = generate_device
self.prefill_device = prefill_device
def forward(self, hidden_states) -> torch.Tensor:
x = hidden_states.view(-1, hidden_states.size(-1))
logits = torch.nn.functional.linear(
x.type(torch.float32), self.orig_module.weight.type(torch.float32), None
)
scores = logits.sigmoid()
from ipex_llm.transformers.models.common import moe_group_topk
topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias,
self.n_group, self.topk_group, self.top_k,
self.norm_topk_prob, self.routed_scaling_factor)
return topk_idx, topk_weight.to(x.dtype)


@ -30,10 +30,11 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
from flashinfer.norm import (
fused_add_rmsnorm,
rmsnorm,
)
if not torch.xpu.is_available():
from flashinfer.norm import (
fused_add_rmsnorm,
rmsnorm,
)
logger = logging.getLogger(__name__)
@ -193,4 +194,29 @@ class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
x = x * torch.rsqrt(variance + self.variance_epsilon)
if residual is not None:
return self.weight * x.to(input_dtype), residual
return self.weight * x.to(input_dtype)
return self.weight * x.to(input_dtype)
class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "xpu",
generate_device: str = "xpu",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.hidden_size,
orig_module.variance_epsilon)
self.eps = orig_module.variance_epsilon
def forward(self, x: torch.Tensor) -> torch.Tensor:
from ipex_llm.transformers.models.common import rms_norm_forward
output = rms_norm_forward(self, x.float())
return output.to(x.dtype)
def load(self):
BaseInjectedModule.load(self)
if self.weight.dtype != torch.float32:
self.weight = self.weight.float()


@ -14,18 +14,20 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import ctypes
import torch
from torch import Tensor, nn
import KTransformersOps
import vLLMMarlin
if not torch.xpu.is_available():
import KTransformersOps
import vLLMMarlin
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
from ktransformers.util.utils import InferenceState
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
MarlinWorkspace,
marlin_quantize,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MIN_THREAD_K,
GPTQ_MARLIN_MAX_PARALLEL,
vllm_marlin_quantize
)
if not torch.xpu.is_available():
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
MarlinWorkspace,
marlin_quantize,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MIN_THREAD_K,
GPTQ_MARLIN_MAX_PARALLEL,
vllm_marlin_quantize
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
@ -778,6 +780,75 @@ class KLinearCPUInfer(KLinearBase):
if self.has_bias:
self.bias = None
class KLinearIPEXLLM(KLinearBase):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "xpu",
precision: str = "sym_int4",
**kwargs,
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False
self.dtype = torch.get_default_dtype()
self.weight = None
self.has_bias = False
self.precision = precision
self.qtype = None
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
dtype = x.dtype
out_device = x.device
from ipex_llm.transformers.models.common import linear_forward
x = linear_forward(x.half(), self.weight, self.qtype, self.out_features)
if self.has_bias:
x = x + self.bias
x = x.to(dtype=dtype, device=out_device)
return x
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if self.loaded: return
if device is None: device = self.device
assert device.lower()[:3] == "xpu", "IPEX-LLM quantized linear only supports XPU device"
if w is None: w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
try:
weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
except:
weight = w.to(dtype=self.dtype).T
self.has_bias = False
elif isinstance(w, tuple):
try:
weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
except:
weight = w[0].to(dtype=self.dtype).T
self.bias = w[1].to(dtype=self.dtype)
self.has_bias = True
else:
raise ValueError("Invalid weight type")
weight = weight.to("cpu").float().transpose(0, 1).contiguous()
if self.has_bias:
self.bias = self.bias.to(device)
# quantize linear weight
from ipex_llm.transformers.models.common import quantize_linear
paramsLowBit, qtype = quantize_linear(weight, self.in_features, self.precision)
self.weight = paramsLowBit.to(device)
self.qtype = qtype
self.loaded = True
def unload(self):
if self.weight is not None:
self.weight = None
if self.has_bias:
self.bias = None
LINEAR_MAP = {
"KLinearMarlin": KLinearMarlin,
"KLinearTorch": KLinearTorch,
@ -785,6 +856,7 @@ LINEAR_MAP = {
"VLinearMarlin": VLinearMarlin,
"KLinearFP8": KLinearFP8,
"KLinearQ8": KLinearQ8,
"KLinearIPEXLLM": KLinearIPEXLLM,
}
class KTransformersLinear(BaseInjectedModule, KLinearBase):


@ -647,6 +647,13 @@ class KDeepseekV2Model(BaseInjectedModule):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
if inputs_embeds.device.type == "xpu" and position_ids is not None:
cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,
position_ids)
position_embeddings = (cos, sin)
else:
position_embeddings = None
if per_layer_prefill_flag:
causal_mask = None
else:
@ -737,6 +744,7 @@ class KDeepseekV2Model(BaseInjectedModule):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
t5 = time.time()
if per_layer_prefill_flag:


@ -103,7 +103,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
for name, child in module._modules.items():
if child is not None:
child_prefix = prefix + name + "."
gen_optimize_config(child, out_data, rule_list, child_prefix)
gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device)
def translate_model_config(model_config: PretrainedConfig):
@ -127,8 +127,11 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
with torch.device("meta"):
inject(module, optimize_config, model_config, weights_loader)
# pre load lm_head because its big inter result
load_weights(module.lm_head, weights_loader, "lm_head.")
load_weights(module, weights_loader)
load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
load_weights(module, weights_loader, device=default_device)
module.gguf_loader = weights_loader
del_meta(module)
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()


@ -0,0 +1,64 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
name: "^model\\.layers\\..*" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
generate_op: "KLinearIPEXLLM"
prefill_op: "KLinearIPEXLLM"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "xpu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "xpu"
recursive: False # don't recursively inject submodules of this module
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2RMSNorm
replace:
class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
device: "xpu"
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"


@ -0,0 +1,81 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
generate_op: "KLinearIPEXLLM"
prefill_op: "KLinearIPEXLLM"
- match:
name: "^model\\.layers\\..*" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
generate_op: "KLinearIPEXLLM"
prefill_op: "KLinearIPEXLLM"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
replace:
class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGateIPEXLLM
kwargs:
generate_device: "xpu:0"
prefill_device: "xpu:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "xpu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "xpu"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"


@ -24,7 +24,8 @@ from typing import Sequence
import os
from enum import IntEnum
import torch
import KTransformersOps
if not torch.xpu.is_available():
import KTransformersOps
import ctypes
import math


@ -7,7 +7,8 @@ from typing import Sequence
import os
from enum import IntEnum
import torch
import KTransformersOps
if not torch.xpu.is_available():
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from ktransformers.util.custom_gguf import *
@ -459,7 +460,7 @@ class GGUFLoader(ModelLoader):
values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
else:
values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values)
values = torch.from_numpy(values).to(device)
if ggml_name == "BF16":
values = values.view(torch.bfloat16)


@ -27,7 +27,8 @@ from ktransformers.operators import base_operator
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
if not torch.xpu.is_available():
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
import socket
warm_uped = False
@ -59,6 +60,8 @@ def get_compute_capability(device:torch.device = None):
return min_compute_capability_major
else:
return torch.cuda.get_device_properties(device)
else:
return 0
def set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
@ -97,7 +100,7 @@ def get_all_used_cuda_device(device_map:dict):
all_device_list = list(all_device_list)
return all_device_list
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = ""):
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
prefix = prefix.replace("orig_module.", "")
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
@ -118,7 +121,10 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
set_param(module, name, weights)
del weights
@ -126,12 +132,24 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
#print(load_config.tensor_file_map.keys())
raise Exception(f"can't find {translated_key} in GGUF file!")
def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix=''):
def sync_all_device(all_device_list):
for device in all_device_list:
if "cuda" in device.lower():
torch.cuda.synchronize(device)
elif "xpu" in device.lower():
torch.xpu.synchronize(device)
else:
raise RuntimeError("The device {} is not available".format(device))
torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}
def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
#print(f"recursively loading weights {prefix}")
if not isinstance(module, base_operator.BaseInjectedModule):
load_cur_state_dict(module, gguf_loader, prefix)
load_cur_state_dict(module, gguf_loader, prefix, device=device)
for name, child in module._modules.items():
load_weights(child, gguf_loader, prefix+name+".")
load_weights(child, gguf_loader, prefix+name+".", device=device)
else:
module.load()
@ -194,8 +212,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
torch._dynamo.config.suppress_errors = True
batch_size, seq_length = inputs.shape
device_map = model.gguf_loader.tensor_device_map
torch_device = get_device('blk.0.self_attn', device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
torch_device = get_device('model.layers.0.self_attn', device_map)
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
inputs = inputs.to(torch_device)
all_cuda_device = get_all_used_cuda_device(device_map)
@ -208,7 +226,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
logits = cuda_graph_runner(cur_token, position_ids, cache_position)
else:
# custom_stream = torch.cuda.Stream()
torch.cuda.set_device(torch_device)
if torch.cuda.is_available():
torch.cuda.set_device(torch_device)
elif torch.xpu.is_available():
torch.xpu.set_device(torch_device)
else:
RuntimeError("The device: {torch_device} is not available")
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
# with torch.cuda.stream(custom_stream):
logits=model(inputs_embeds=inputs_embeds,
@ -216,10 +239,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
if past_key_values != None:
if past_key_values != None and isinstance(past_key_values, StaticCache):
past_key_values.change_seq_length(1)
for device in all_cuda_device:
torch.cuda.synchronize(device)
sync_all_device(all_cuda_device)
#print(logits)
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
@ -245,11 +267,19 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
return logits
torch.cuda.set_device(torch_device)
if torch.cuda.is_available():
torch.cuda.set_device(torch_device)
elif torch.xpu.is_available():
torch.xpu.set_device(torch_device)
else:
RuntimeError("The device: {torch_device} is not available")
with torch.no_grad():
stream = TextStreamer(tokenizer)
if mode != 'long_context':
if torch.xpu.is_available():
from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
elif mode != 'long_context':
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
)


@ -39,7 +39,8 @@ try:
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
except ImportError:
MUSA_HOME=None
KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()
with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"
class CpuInstructInfo:
@ -225,6 +226,8 @@ class VersionInfo:
backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
elif ROCM_HOME is not None:
backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
elif torch.xpu.is_available():
backend_version = f"xpu"
else:
raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
@ -495,6 +498,8 @@ class CMakeBuild(BuildExtension):
cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
elif ROCM_HOME is not None:
cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
elif KTRANSFORMERS_BUILD_XPU:
cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
else:
raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.")
@ -620,29 +625,36 @@ elif MUSA_HOME is not None:
]
}
)
elif torch.xpu.is_available(): #XPUExtension is not available now.
ops_module = None
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
ops_module,
CUDAExtension(
'vLLMMarlin', [
'csrc/custom_marlin/binding.cpp',
'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
},
)
]
if with_balance:
print("using balance_serve")
ext_modules.append(
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
)
if not torch.xpu.is_available():
ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
ops_module,
CUDAExtension(
'vLLMMarlin', [
'csrc/custom_marlin/binding.cpp',
'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
},
)
]
if with_balance:
print("using balance_serve")
ext_modules.append(
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
)
else:
ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
]
setup(
name=VersionInfo.PACKAGE_NAME,