diff --git a/README.md b/README.md
index 6ad3a59..172694a 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,8 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
+ **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).
+
* **Apr 29, 2025**: Support AMX-Int8、 AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))
https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2
diff --git a/csrc/ktransformers_ext/CMakeLists.txt b/csrc/ktransformers_ext/CMakeLists.txt
index 217de78..0ed4ef4 100644
--- a/csrc/ktransformers_ext/CMakeLists.txt
+++ b/csrc/ktransformers_ext/CMakeLists.txt
@@ -41,6 +41,7 @@ option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" ON)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
+option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
@@ -303,6 +304,8 @@ elseif (UNIX)
message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif()
+ elseif (KTRANSFORMERS_USE_XPU)
+ add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
else()
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
@@ -361,6 +364,7 @@ elseif(UNIX)
message(STATUS "Building for HIP")
elseif(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
+ elseif(KTRANSFORMERS_USE_XPU)
else()
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()
diff --git a/csrc/ktransformers_ext/cpu_backend/cpuinfer.h b/csrc/ktransformers_ext/cpu_backend/cpuinfer.h
index 9c7e781..7b1d898 100644
--- a/csrc/ktransformers_ext/cpu_backend/cpuinfer.h
+++ b/csrc/ktransformers_ext/cpu_backend/cpuinfer.h
@@ -17,6 +17,7 @@
#include
#include
#include
+ #include
#ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
@@ -66,10 +67,14 @@
}
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair params) {
+ #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)
void (*func)(void*) = (void (*)(void*))params.first;
void* args = (void*)params.second;
*((CPUInfer**)args) = this;
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
+ #else
+ throw std::runtime_error("submit_with_cuda_stream is not supported on this platforma");
+ #endif
}
static void sync_(void* cpu_infer_ptr) {
@@ -78,7 +83,11 @@
}
void sync_with_cuda_stream(intptr_t user_cuda_stream) {
+ #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
+ #else
+ throw std::runtime_error("sync_with_cuda_stream is not supported on this platforma");
+ #endif
}
public:
diff --git a/csrc/ktransformers_ext/ext_bindings.cpp b/csrc/ktransformers_ext/ext_bindings.cpp
index 2767679..f0aeaa5 100644
--- a/csrc/ktransformers_ext/ext_bindings.cpp
+++ b/csrc/ktransformers_ext/ext_bindings.cpp
@@ -9,7 +9,7 @@
**/
// Python bindings
#include "cpu_backend/cpuinfer.h"
-#ifndef KTRANSFORMERS_USE_ROCM
+#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)
#include "device_launch_parameters.h"
#endif
#include "llamafile/flags.h"
diff --git a/doc/README.md b/doc/README.md
index 05df2d3..1e994aa 100644
--- a/doc/README.md
+++ b/doc/README.md
@@ -21,7 +21,7 @@ interface, RESTful APIs compliant with OpenAI and Ollama, and even a simplified
Our vision for KTransformers is to serve as a flexible platform for experimenting with innovative LLM inference optimizations. Please let us know if you need any other features.
🔥 Updates
-
+* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./en/xpu.md)).
* **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./en/llama4.md)).
* **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./en/balance-serve.md)).
* **Mar 27, 2025**: Support Multi-concurrency.
diff --git a/doc/en/xpu.md b/doc/en/xpu.md
new file mode 100644
index 0000000..4f71ae7
--- /dev/null
+++ b/doc/en/xpu.md
@@ -0,0 +1,117 @@
+# Intel GPU Support for KTransformers (Beta)
+
+## Introduction
+
+### Overview
+We are excited to introduce **Intel GPU support** in KTransformers (Beta release). This implementation has been tested and developed using Intel Xeon Scalable processors and Intel Arc GPU's (such as A770 and B580).
+
+## Installation Guide
+
+### 1. Install Intel GPU Driver
+Begin by installing the GPU drivers for your Intel GPU:
+- [Official GPU Installation Guide for Intel GPUs](https://dgpu-docs.intel.com/driver/overview.html)
+
+> [!Important]
+> Ensure that **Resizable BAR** is enabled in your system's BIOS before proceeding. This is essential for optimal GPU performance and to avoid potential issues such as `Bus error (core dumped)`. For detailed steps, please refer to the official guidance [here](https://www.intel.com/content/www/us/en/support/articles/000090831/graphics.html).
+
+### 2. Set Up Conda Environment
+We recommend using Miniconda3/Anaconda3 for environment management:
+
+```bash
+# Download Miniconda
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+
+# Create environment
+conda create --name ktransformers python=3.11
+conda activate ktransformers
+
+# Install required libraries
+conda install -c conda-forge libstdcxx-ng
+
+# Verify GLIBCXX version (should include 3.4.32)
+strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
+```
+
+> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`
+
+### 3. Install PyTorch and IPEX-LLM
+Install PyTorch with XPU backend support and [IPEX-LLM](https://github.com/intel/ipex-llm):
+
+```bash
+pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu
+pip uninstall torch torchvision torchaudio
+pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu # install torch2.7
+pip install packaging ninja cpufeature numpy
+pip uninstall intel-opencl-rt dpcpp-cpp-rt
+```
+
+### 4. Build ktransformers
+
+```bash
+# Clone repository
+git clone https://github.com/kvcache-ai/ktransformers.git
+cd ktransformers
+git submodule update --init
+
+# Install dependencies
+bash install.sh
+pip uninstall triton pytorch-triton-xpu
+pip install pytorch-triton-xpu==3.3.0 --extra-index-url https://download.pytorch.org/whl/xpu # to avoid potential triton import error
+```
+
+## Running DeepSeek-R1 Models
+
+### Configuration for 16B VRAM GPUs
+Use our optimized configuration for constrained VRAM:
+
+```bash
+export SYCL_CACHE_PERSISTENT=1
+export ONEAPI_DEVICE_SELECTOR=level_zero:0
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+
+python ktransformers/local_chat.py \
+ --model_path deepseek-ai/DeepSeek-R1 \
+ --gguf_path \
+ --optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \
+ --cpu_infer \
+ --device xpu \
+ --max_new_tokens 200
+```
+
+## Known Limitations
+- Serving function is not supported on Intel GPU platform for now
+
+## Troubleshooting
+1. Best Known Config (BKC) to obtain best performance
+
+To obtain best performance on Intel GPU platform, we recommand to lock GPU frequency and set CPU to performance mode by below settings.
+```bash
+echo "performance" | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
+echo 0 | sudo tee /sys/devices/system/cpu/cpu*/power/energy_perf_bias
+# 2400 is max frequency for Arc A770
+sudo xpu-smi config -d 0 -t 0 --frequencyrange 2400,2400
+# 2850 is max frequency for Arc B580
+# sudo xpu-smi config -d 0 -t 0 --frequencyrange 2850,2850
+```
+
+2. Runtime error like `xpu/sycl/TensorCompareKernels.cpp:163: xxx. Aborted (core dumped)`
+
+This error is mostly realted to GPU driver. If you meet such error, you could update your `intel-level-zero-gpu` to `1.3.29735.27-914~22.04` (which is a verified version by us) by below command.
+```bash
+wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \
+sudo gpg --dearmor --output /usr/share/keyrings/intel-graphics.gpg
+echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu jammy client" | \
+sudo tee /etc/apt/sources.list.d/intel-gpu-jammy.list
+sudo apt update
+# or sudo apt update --allow-insecure-repositories
+sudo apt install intel-level-zero-gpu=1.3.29735.27-914~22.04
+```
+
+3. `ImportError: cannot import name 'intel' from 'triton._C.libtriton'`
+
+Installing Triton causes pytorch-triton-xpu to stop working. You can resolve the issue with following command:
+```bash
+pip uninstall triton pytorch-triton-xpu
+# Reinstall correct version of pytorch-triton-xpu
+pip install pytorch-triton-xpu==3.3.0 --index-url https://download.pytorch.org/whl/xpu
+```
\ No newline at end of file
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 928de48..173ce07 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -63,18 +63,23 @@ def local_chat(
prompt_file : str | None = None,
mode: str = "normal",
force_think: bool = False,
- chunk_size: int = 8192
+ chunk_size: int = 8192,
+ device: str = "cuda"
):
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
+ if torch.xpu.is_available():
+ use_cuda_graph = False
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if mode == 'long_context':
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
+ elif torch.xpu.is_available() and config.architectures[0] == "DeepseekV3ForCausalLM":
+ torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
@@ -109,7 +114,7 @@ def local_chat(
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
- optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
+ optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, default_device=device)
try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
@@ -172,12 +177,12 @@ def local_chat(
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
generated = prefill_and_generate(
- model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
+ model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
)
else:
generated = prefill_and_generate(
- model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
+ model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
)
diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py
index c2901ac..1121b4a 100644
--- a/ktransformers/models/custom_cache.py
+++ b/ktransformers/models/custom_cache.py
@@ -213,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module):
self.v_caches = []
- def load(self, inference_context: "sched_ext.InferenceContext"):
+ def load(self, inference_context: "sched_ext.InferenceContext"):
for i in range(self.config.num_hidden_layers):
self.k_caches.append(
@@ -293,7 +293,7 @@ class KGQACache(nn.Module):
self.v_caches = []
- def load(self, inference_context: sched_ext.InferenceContext):
+ def load(self, inference_context: "sched_ext.InferenceContext"):
print(self.config.num_hidden_layers)
for i in range(self.config.num_hidden_layers):
self.k_caches.append(
diff --git a/ktransformers/models/modeling_deepseek.py b/ktransformers/models/modeling_deepseek.py
index e14a521..f6845ec 100644
--- a/ktransformers/models/modeling_deepseek.py
+++ b/ktransformers/models/modeling_deepseek.py
@@ -107,6 +107,7 @@ class DeepseekV2RMSNorm(nn.Module):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
+ self.hidden_size = hidden_size
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index caceb98..41dbf5a 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -587,6 +587,100 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
return attn_output, None, past_key_value
+ def forward_xpu(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.q_lora_rank is None:
+ q = self.q_proj(hidden_states)
+ else:
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+ query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ compressed_kv, k_pe = torch.split(
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+ )
+ k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+ kv = (
+ self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+ .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+ .transpose(1, 2)
+ )
+
+ k_nope, value_states = torch.split(
+ kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+ )
+ kv_seq_len = value_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ position_embeddings = kwargs.get("position_embeddings", None)
+ if position_embeddings is not None:
+ cos, sin = position_embeddings
+ key_states = torch.cat(
+ [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
+ dim=-1
+ )
+ from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+ rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim :],
+ key_states[:, :, :, self.qk_nope_head_dim:],
+ cos, sin, True)
+ else:
+ q_nope, q_pe = torch.split(
+ query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+ cos, sin = self.rotary_emb(q_pe, position_ids)
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
+ query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+ query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+ query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+
+ key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+ key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+ key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(
+ key_states.half(), value_states.half(), self.layer_idx, cache_kwargs
+ )
+
+ attn_weights = None
+ from ipex_llm.transformers.models.common import scaled_dot_product_attention
+ attn_output = scaled_dot_product_attention(
+ query_states.half(), key_states, value_states,
+ attention_mask.half(), q_len == kv_seq_len, self.softmax_scale
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+ attn_output = self.o_proj(attn_output).to(hidden_states.dtype)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -598,10 +692,21 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if (os.name == 'nt'
- or get_compute_capability() < 8
- or hidden_states.device.type == 'cpu'
- or device_manager.gpu_vendor != GPUVendor.NVIDIA):
+ if torch.xpu.is_available():
+ return self.forward_xpu(
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_value,
+ output_attentions,
+ use_cache,
+ cache_position,
+ **kwargs,
+ )
+ elif (os.name == 'nt'
+ or get_compute_capability() < 8
+ or hidden_states.device.type == 'cpu'
+ or device_manager.gpu_vendor != GPUVendor.NVIDIA):
return self.forward_windows(
hidden_states,
attention_mask,
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 279439d..d7d1926 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -51,7 +51,10 @@ def generate_cuda_graphs(chunk_size: int) -> list:
return deduplicate_and_sort(base_list + multiples)
#cuda_graphs = [Config().chunk_size]
-cuda_graphs = generate_cuda_graphs(Config().chunk_size)
+if torch.cuda.is_available():
+ cuda_graphs = generate_cuda_graphs(Config().chunk_size)
+else:
+ cuda_graphs = 1
# class Base(BaseInjectedModule, ABC):
class KExpertsBase(ABC):
def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
@@ -177,6 +180,11 @@ class KExpertsCPU(KExpertsBase):
n_routed_experts = self.n_routed_experts
self.cpu_infer = KExpertsCPU.CPU_INFER
# n_routed_experts = len(self.orig_module)
+ model_dtype = torch.get_default_dtype()
+ if torch.xpu.is_available() and model_dtype == torch.float16:
+ hidden_type = 1 # fp16
+ else:
+ hidden_type = 30 # bf16
if self.backend == "llamafile":
moe_config = MOEConfig(
n_routed_experts,
@@ -192,7 +200,7 @@ class KExpertsCPU(KExpertsBase):
self.gate_type,
self.up_type,
self.down_type,
- 30, # TODO: get from model.dtype
+ hidden_type, # TODO: get from model.dtype
)
self.moe = MOE(moe_config)
elif self.backend == "AMXBF16":
@@ -252,8 +260,12 @@ class KExpertsCPU(KExpertsBase):
KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
- KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
- KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
+ if torch.xpu.is_available():
+ KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=model_dtype)
+ KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True)
+ else:
+ KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
+ KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
if bsz_tensor is None:
@@ -285,9 +297,9 @@ class KExpertsCPU(KExpertsBase):
def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
# generate, capture and run cuda graph
# print(expert_ids)
- if bsz_tensor is None:
+ if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1):
bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
- if torch.cuda.is_current_stream_capturing():
+ if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
if cuda_graph_idx != -1:
KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
@@ -307,6 +319,15 @@ class KExpertsCPU(KExpertsBase):
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device]
+ elif input_tensor.size(0)==1 and torch.xpu.is_available():
+ KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True)
+ KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True)
+ KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True)
+ # KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True)
+ self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
+ self.cpu_infer.sync()
+ KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
+ return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1)
else:
input_tensor = input_tensor.contiguous().cpu()
expert_ids = expert_ids.contiguous().cpu()
@@ -822,7 +843,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
- if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+ if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
@@ -922,7 +943,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
# only for generate phase
- if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+ if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
@@ -1122,7 +1143,7 @@ class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):
# only for generate phase
- if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+ if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
@@ -1304,7 +1325,7 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
routing_weights = routing_weights.to(hidden_states.dtype)
# only for generate phase
- if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+ if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
@@ -1417,7 +1438,7 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
routing_weights = routing_weights.to(hidden_states.dtype)
# only for generate phase
- if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
+ if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py
index a6f95ac..f5f96c1 100644
--- a/ktransformers/operators/gate.py
+++ b/ktransformers/operators/gate.py
@@ -182,4 +182,34 @@ class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
if self.weight is not None:
self.weight = None
if self.e_score_correction_bias is not None:
- self.e_score_correction_bias = None
\ No newline at end of file
+ self.e_score_correction_bias = None
+
+
+class KMoEGateIPEXLLM(KMoEGate):
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module = None,
+ generate_device: str = "xpu",
+ prefill_device: str = "xpu",
+ **kwargs,
+ ):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
+ KMoEGate.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+ self.generate_device = generate_device
+ self.prefill_device = prefill_device
+
+ def forward(self, hidden_states) -> torch.Tensor:
+ x = hidden_states.view(-1, hidden_states.size(-1))
+ logits = torch.nn.functional.linear(
+ x.type(torch.float32), self.orig_module.weight.type(torch.float32), None
+ )
+ scores = logits.sigmoid()
+
+ from ipex_llm.transformers.models.common import moe_group_topk
+ topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias,
+ self.n_group, self.topk_group, self.top_k,
+ self.norm_topk_prob, self.routed_scaling_factor)
+ return topk_idx, topk_weight.to(x.dtype)
\ No newline at end of file
diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py
index 22d580b..6d616d1 100644
--- a/ktransformers/operators/layernorm.py
+++ b/ktransformers/operators/layernorm.py
@@ -30,10 +30,11 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
-from flashinfer.norm import (
- fused_add_rmsnorm,
- rmsnorm,
-)
+if not torch.xpu.is_available():
+ from flashinfer.norm import (
+ fused_add_rmsnorm,
+ rmsnorm,
+ )
logger = logging.getLogger(__name__)
@@ -193,4 +194,29 @@ class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
x = x * torch.rsqrt(variance + self.variance_epsilon)
if residual is not None:
return self.weight * x.to(input_dtype), residual
- return self.weight * x.to(input_dtype)
\ No newline at end of file
+ return self.weight * x.to(input_dtype)
+
+
+class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ prefill_device: str = "xpu",
+ generate_device: str = "xpu",
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+ self.orig_module.__init__(orig_module.hidden_size,
+ orig_module.variance_epsilon)
+ self.eps = orig_module.variance_epsilon
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ from ipex_llm.transformers.models.common import rms_norm_forward
+ output = rms_norm_forward(self, x.float())
+ return output.to(x.dtype)
+
+ def load(self):
+ BaseInjectedModule.load(self)
+ if self.weight.dtype != torch.float32:
+ self.weight = self.weight.float()
\ No newline at end of file
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index 2b12b15..9ce45d1 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -14,18 +14,20 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import ctypes
import torch
from torch import Tensor, nn
-import KTransformersOps
-import vLLMMarlin
+if not torch.xpu.is_available():
+ import KTransformersOps
+ import vLLMMarlin
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
from ktransformers.util.utils import InferenceState
-from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
- MarlinWorkspace,
- marlin_quantize,
- GPTQ_MARLIN_MIN_THREAD_N,
- GPTQ_MARLIN_MIN_THREAD_K,
- GPTQ_MARLIN_MAX_PARALLEL,
- vllm_marlin_quantize
-)
+if not torch.xpu.is_available():
+ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
+ MarlinWorkspace,
+ marlin_quantize,
+ GPTQ_MARLIN_MIN_THREAD_N,
+ GPTQ_MARLIN_MIN_THREAD_K,
+ GPTQ_MARLIN_MAX_PARALLEL,
+ vllm_marlin_quantize
+ )
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
@@ -778,6 +780,75 @@ class KLinearCPUInfer(KLinearBase):
if self.has_bias:
self.bias = None
+class KLinearIPEXLLM(KLinearBase):
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module = None,
+ device: str = "xpu",
+ precision: str = "sym_int4",
+ **kwargs,
+ ):
+ super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+ self.has_bias = False
+ self.dtype = torch.get_default_dtype()
+ self.weight = None
+ self.has_bias = False
+ self.precision = precision
+ self.qtype = None
+
+ def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
+ dtype = x.dtype
+ out_device = x.device
+ from ipex_llm.transformers.models.common import linear_forward
+ x = linear_forward(x.half(), self.weight, self.qtype, self.out_features)
+
+ if self.has_bias:
+ x = x + self.bias
+ x = x.to(dtype=dtype, device=out_device)
+ return x
+
+ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+ if self.loaded: return
+ if device is None: device = self.device
+ assert device.lower()[:3] == "xpu", "IPEX-LLM quantized linear only supports XPU device"
+ if w is None: w = self.load_weight(device=device)
+
+ if isinstance(w, nn.Parameter):
+ try:
+ weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+ except:
+ weight = w.to(dtype=self.dtype).T
+ self.has_bias = False
+ elif isinstance(w, tuple):
+ try:
+ weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+ except:
+ weight = w[0].to(dtype=self.dtype).T
+ self.bias = w[1].to(dtype=self.dtype)
+ self.has_bias = True
+ else:
+ raise ValueError("Invalid weight type")
+ weight = weight.to("cpu").float().transpose(0, 1).contiguous()
+
+ if self.has_bias:
+ self.bias = self.bias.to(device)
+
+ # quantize linear weight
+ from ipex_llm.transformers.models.common import quantize_linear
+ paramsLowBit, qtype = quantize_linear(weight, self.in_features, self.precision)
+ self.weight = paramsLowBit.to(device)
+ self.qtype = qtype
+ self.loaded = True
+
+ def unload(self):
+ if self.weight is not None:
+ self.weight = None
+ if self.has_bias:
+ self.bias = None
+
LINEAR_MAP = {
"KLinearMarlin": KLinearMarlin,
"KLinearTorch": KLinearTorch,
@@ -785,6 +856,7 @@ LINEAR_MAP = {
"VLinearMarlin": VLinearMarlin,
"KLinearFP8": KLinearFP8,
"KLinearQ8": KLinearQ8,
+ "KLinearIPEXLLM": KLinearIPEXLLM,
}
class KTransformersLinear(BaseInjectedModule, KLinearBase):
diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py
index 4aa223d..8299d4c 100644
--- a/ktransformers/operators/models.py
+++ b/ktransformers/operators/models.py
@@ -647,6 +647,13 @@ class KDeepseekV2Model(BaseInjectedModule):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
+ if inputs_embeds.device.type == "xpu" and position_ids is not None:
+ cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,
+ position_ids)
+ position_embeddings = (cos, sin)
+ else:
+ position_embeddings = None
+
if per_layer_prefill_flag:
causal_mask = None
else:
@@ -737,6 +744,7 @@ class KDeepseekV2Model(BaseInjectedModule):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
+ position_embeddings=position_embeddings,
)
t5 = time.time()
if per_layer_prefill_flag:
diff --git a/ktransformers/optimize/optimize.py b/ktransformers/optimize/optimize.py
index 72c8407..bbe08c8 100644
--- a/ktransformers/optimize/optimize.py
+++ b/ktransformers/optimize/optimize.py
@@ -103,7 +103,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
for name, child in module._modules.items():
if child is not None:
child_prefix = prefix + name + "."
- gen_optimize_config(child, out_data, rule_list, child_prefix)
+ gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device)
def translate_model_config(model_config: PretrainedConfig):
@@ -127,8 +127,11 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
with torch.device("meta"):
inject(module, optimize_config, model_config, weights_loader)
# pre load lm_head because its big inter result
- load_weights(module.lm_head, weights_loader, "lm_head.")
- load_weights(module, weights_loader)
+ load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
+ load_weights(module, weights_loader, device=default_device)
module.gguf_loader = weights_loader
del_meta(module)
- torch.cuda.empty_cache()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ elif torch.xpu.is_available():
+ torch.xpu.empty_cache()
diff --git a/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml
new file mode 100644
index 0000000..5de582f
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml
@@ -0,0 +1,64 @@
+- match:
+ class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbedding
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^model\\.layers\\..*" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+ generate_op: "KLinearIPEXLLM"
+ prefill_op: "KLinearIPEXLLM"
+- match:
+ name: "^model\\.layers\\..*\\.mlp$"
+ class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^model\\.layers\\..*\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "xpu"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "xpu"
+ recursive: False # don't recursively inject submodules of this module
+- match:
+ class: ktransformers.models.modeling_deepseek.DeepseekV2RMSNorm
+ replace:
+ class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+ per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
+ device: "xpu"
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
\ No newline at end of file
diff --git a/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml
new file mode 100644
index 0000000..c0e46c3
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml
@@ -0,0 +1,81 @@
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ name: "^lm_head$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+ generate_op: "KLinearIPEXLLM"
+ prefill_op: "KLinearIPEXLLM"
+- match:
+ name: "^model\\.layers\\..*" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+ generate_op: "KLinearIPEXLLM"
+ prefill_op: "KLinearIPEXLLM"
+- match:
+ name: "^model\\.layers\\..*\\.mlp$"
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
+ replace:
+ class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.MoEGate
+ replace:
+ class: ktransformers.operators.gate.KMoEGateIPEXLLM
+ kwargs:
+ generate_device: "xpu:0"
+ prefill_device: "xpu:0"
+- match:
+ name: "^model\\.layers\\..*\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "xpu"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "xpu"
+ recursive: False # don't recursively inject submodules of this module
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "xpu"
+ prefill_device: "xpu"
+ absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+ per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
\ No newline at end of file
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index 4518366..5e4ffd6 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -24,7 +24,8 @@ from typing import Sequence
import os
from enum import IntEnum
import torch
-import KTransformersOps
+if not torch.xpu.is_available():
+ import KTransformersOps
import ctypes
import math
diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py
index 93b94c4..5adaaeb 100644
--- a/ktransformers/util/custom_loader.py
+++ b/ktransformers/util/custom_loader.py
@@ -7,7 +7,8 @@ from typing import Sequence
import os
from enum import IntEnum
import torch
-import KTransformersOps
+if not torch.xpu.is_available():
+ import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from ktransformers.util.custom_gguf import *
@@ -459,7 +460,7 @@ class GGUFLoader(ModelLoader):
values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
else:
values = GGML_DEQUANTIZE[ggml_name](data)
- values = torch.from_numpy(values)
+ values = torch.from_numpy(values).to(device)
if ggml_name == "BF16":
values = values.view(torch.bfloat16)
diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index 3ddf639..bdf19fe 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -27,7 +27,8 @@ from ktransformers.operators import base_operator
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer
-from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
+if not torch.xpu.is_available():
+ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
import socket
warm_uped = False
@@ -59,6 +60,8 @@ def get_compute_capability(device:torch.device = None):
return min_compute_capability_major
else:
return torch.cuda.get_device_properties(device)
+ else:
+ return 0
def set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
@@ -97,7 +100,7 @@ def get_all_used_cuda_device(device_map:dict):
all_device_list = list(all_device_list)
return all_device_list
-def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = ""):
+def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
prefix = prefix.replace("orig_module.", "")
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
@@ -118,7 +121,10 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
- torch.cuda.empty_cache()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ elif torch.xpu.is_available():
+ torch.xpu.empty_cache()
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
set_param(module, name, weights)
del weights
@@ -126,12 +132,24 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
#print(load_config.tensor_file_map.keys())
raise Exception(f"can't find {translated_key} in GGUF file!")
-def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix=''):
+
+def sync_all_device(all_device_list):
+ for device in all_device_list:
+ if "cuda" in device.lower():
+ torch.cuda.synchronize(device)
+ elif "xpu" in device.lower():
+ torch.xpu.synchronize(device)
+ else:
+ raise RuntimeError("The device {} is not available".format(device))
+
+torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}
+
+def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
#print(f"recursively loading weights {prefix}")
if not isinstance(module, base_operator.BaseInjectedModule):
- load_cur_state_dict(module, gguf_loader, prefix)
+ load_cur_state_dict(module, gguf_loader, prefix, device=device)
for name, child in module._modules.items():
- load_weights(child, gguf_loader, prefix+name+".")
+ load_weights(child, gguf_loader, prefix+name+".", device=device)
else:
module.load()
@@ -194,8 +212,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
torch._dynamo.config.suppress_errors = True
batch_size, seq_length = inputs.shape
device_map = model.gguf_loader.tensor_device_map
- torch_device = get_device('blk.0.self_attn', device_map)
- torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+ torch_device = get_device('model.layers.0.self_attn', device_map)
+ torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
inputs = inputs.to(torch_device)
all_cuda_device = get_all_used_cuda_device(device_map)
@@ -208,7 +226,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
logits = cuda_graph_runner(cur_token, position_ids, cache_position)
else:
# custom_stream = torch.cuda.Stream()
- torch.cuda.set_device(torch_device)
+ if torch.cuda.is_available():
+ torch.cuda.set_device(torch_device)
+ elif torch.xpu.is_available():
+ torch.xpu.set_device(torch_device)
+ else:
+ RuntimeError("The device: {torch_device} is not available")
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
# with torch.cuda.stream(custom_stream):
logits=model(inputs_embeds=inputs_embeds,
@@ -216,10 +239,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
- if past_key_values != None:
+ if past_key_values != None and isinstance(past_key_values, StaticCache):
past_key_values.change_seq_length(1)
- for device in all_cuda_device:
- torch.cuda.synchronize(device)
+ sync_all_device(all_cuda_device)
#print(logits)
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
@@ -245,11 +267,19 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
return logits
- torch.cuda.set_device(torch_device)
+ if torch.cuda.is_available():
+ torch.cuda.set_device(torch_device)
+ elif torch.xpu.is_available():
+ torch.xpu.set_device(torch_device)
+ else:
+ RuntimeError("The device: {torch_device} is not available")
with torch.no_grad():
stream = TextStreamer(tokenizer)
- if mode != 'long_context':
+ if torch.xpu.is_available():
+ from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
+ past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+ elif mode != 'long_context':
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
)
diff --git a/setup.py b/setup.py
index c5bf128..b8f318d 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,8 @@ try:
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
except ImportError:
MUSA_HOME=None
-
+KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()
+
with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"
class CpuInstructInfo:
@@ -225,6 +226,8 @@ class VersionInfo:
backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
elif ROCM_HOME is not None:
backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
+ elif torch.xpu.is_available():
+ backend_version = f"xpu"
else:
raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
@@ -495,6 +498,8 @@ class CMakeBuild(BuildExtension):
cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
elif ROCM_HOME is not None:
cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
+ elif KTRANSFORMERS_BUILD_XPU:
+ cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
else:
raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.")
@@ -620,29 +625,36 @@ elif MUSA_HOME is not None:
]
}
)
+elif torch.xpu.is_available(): #XPUExtension is not available now.
+ ops_module = None
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
-ext_modules = [
- CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
- ops_module,
- CUDAExtension(
- 'vLLMMarlin', [
- 'csrc/custom_marlin/binding.cpp',
- 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
- 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
- ],
- extra_compile_args={
- 'cxx': ['-O3'],
- 'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
- },
- )
-]
-if with_balance:
- print("using balance_serve")
- ext_modules.append(
- CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
- )
+if not torch.xpu.is_available():
+ ext_modules = [
+ CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+ ops_module,
+ CUDAExtension(
+ 'vLLMMarlin', [
+ 'csrc/custom_marlin/binding.cpp',
+ 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
+ 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
+ ],
+ extra_compile_args={
+ 'cxx': ['-O3'],
+ 'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
+ },
+ )
+ ]
+ if with_balance:
+ print("using balance_serve")
+ ext_modules.append(
+ CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
+ )
+else:
+ ext_modules = [
+ CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+ ]
setup(
name=VersionInfo.PACKAGE_NAME,