From 142fb7ce6c2739eda197624155374a711a88c01e Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Wed, 14 May 2025 14:28:22 +0000 Subject: [PATCH] Enable support for Intel XPU devices, add support for DeepSeek V2/V3 first --- README.md | 2 + csrc/ktransformers_ext/CMakeLists.txt | 4 + csrc/ktransformers_ext/cpu_backend/cpuinfer.h | 9 ++ csrc/ktransformers_ext/ext_bindings.cpp | 2 +- doc/README.md | 2 +- doc/en/xpu.md | 117 ++++++++++++++++++ ktransformers/local_chat.py | 13 +- ktransformers/models/custom_cache.py | 4 +- ktransformers/models/modeling_deepseek.py | 1 + ktransformers/operators/attention.py | 113 ++++++++++++++++- ktransformers/operators/experts.py | 43 +++++-- ktransformers/operators/gate.py | 32 ++++- ktransformers/operators/layernorm.py | 36 +++++- ktransformers/operators/linear.py | 92 ++++++++++++-- ktransformers/operators/models.py | 8 ++ ktransformers/optimize/optimize.py | 11 +- .../optimize_rules/xpu/DeepSeek-V2-Chat.yaml | 64 ++++++++++ .../optimize_rules/xpu/DeepSeek-V3-Chat.yaml | 81 ++++++++++++ ktransformers/util/custom_gguf.py | 3 +- ktransformers/util/custom_loader.py | 5 +- ktransformers/util/utils.py | 58 ++++++--- setup.py | 54 ++++---- 22 files changed, 673 insertions(+), 81 deletions(-) create mode 100644 doc/en/xpu.md create mode 100644 ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml create mode 100644 ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml diff --git a/README.md b/README.md index 6ad3a59..172694a 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,8 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

+* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).
+
 * **Apr 29, 2025**: Support AMX-Int8、 AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))
 
 https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2
diff --git a/csrc/ktransformers_ext/CMakeLists.txt b/csrc/ktransformers_ext/CMakeLists.txt
index 217de78..0ed4ef4 100644
--- a/csrc/ktransformers_ext/CMakeLists.txt
+++ b/csrc/ktransformers_ext/CMakeLists.txt
@@ -41,6 +41,7 @@ option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW
 option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" ON)
 option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
 option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
+option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
 
 # Architecture specific
 # TODO: probably these flags need to be tweaked on some architectures
@@ -303,6 +304,8 @@ elseif (UNIX)
             message(STATUS "MUSA Toolkit found")
             add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
         endif()
+    elseif (KTRANSFORMERS_USE_XPU)
+        add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
     else()
         find_package(CUDA REQUIRED)
         include_directories("${CUDA_INCLUDE_DIRS}")
@@ -361,6 +364,7 @@ elseif(UNIX)
         message(STATUS "Building for HIP")
     elseif(KTRANSFORMERS_USE_MUSA)
         target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
+    elseif(KTRANSFORMERS_USE_XPU)
     else()
         target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
     endif()
diff --git a/csrc/ktransformers_ext/cpu_backend/cpuinfer.h b/csrc/ktransformers_ext/cpu_backend/cpuinfer.h
index 9c7e781..7b1d898 100644
--- a/csrc/ktransformers_ext/cpu_backend/cpuinfer.h
+++ b/csrc/ktransformers_ext/cpu_backend/cpuinfer.h
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include
 #ifdef KTRANSFORMERS_USE_CUDA
 #include "vendors/cuda.h"
 #elif KTRANSFORMERS_USE_MUSA
@@ -66,10 +67,14 @@
     }
 
     void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair params) {
+        #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)
         void (*func)(void*) = (void (*)(void*))params.first;
         void* args = (void*)params.second;
         *((CPUInfer**)args) = this;
         cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
+        #else
+        throw std::runtime_error("submit_with_cuda_stream is not supported on this platform");
+        #endif
     }
 
     static void sync_(void* cpu_infer_ptr) {
@@ -78,7 +83,11 @@
     }
 
     void sync_with_cuda_stream(intptr_t user_cuda_stream) {
+        #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)
         cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
+        #else
+        throw std::runtime_error("sync_with_cuda_stream is not supported on this platform");
+        #endif
     }
 
    public:
diff --git a/csrc/ktransformers_ext/ext_bindings.cpp b/csrc/ktransformers_ext/ext_bindings.cpp
index 2767679..f0aeaa5 100644
--- a/csrc/ktransformers_ext/ext_bindings.cpp
+++ b/csrc/ktransformers_ext/ext_bindings.cpp
@@ -9,7 +9,7 @@
  **/
 // Python bindings
 #include "cpu_backend/cpuinfer.h"
-#ifndef KTRANSFORMERS_USE_ROCM
+#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)
 #include "device_launch_parameters.h"
 #endif
 #include "llamafile/flags.h"
diff --git a/doc/README.md b/doc/README.md
index 05df2d3..1e994aa 100644
--- a/doc/README.md
+++ b/doc/README.md
@@ -21,7 +21,7 @@ interface, RESTful APIs compliant with OpenAI and Ollama, and even a simplified
 Our vision for KTransformers is to serve as a flexible platform for experimenting with
innovative LLM inference optimizations. Please let us know if you need any other features.

🔥 Updates

-
+* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./en/xpu.md)).
 * **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./en/llama4.md)).
 * **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./en/balance-serve.md)).
 * **Mar 27, 2025**: Support Multi-concurrency.
diff --git a/doc/en/xpu.md b/doc/en/xpu.md
new file mode 100644
index 0000000..4f71ae7
--- /dev/null
+++ b/doc/en/xpu.md
@@ -0,0 +1,117 @@
+# Intel GPU Support for KTransformers (Beta)
+
+## Introduction
+
+### Overview
+We are excited to introduce **Intel GPU support** in KTransformers (Beta release). This implementation has been tested and developed using Intel Xeon Scalable processors and Intel Arc GPUs (such as the A770 and B580).
+
+## Installation Guide
+
+### 1. Install Intel GPU Driver
+Begin by installing the GPU drivers for your Intel GPU:
+- [Official GPU Installation Guide for Intel GPUs](https://dgpu-docs.intel.com/driver/overview.html)
+
+> [!Important]
+> Ensure that **Resizable BAR** is enabled in your system's BIOS before proceeding. This is essential for optimal GPU performance and helps avoid issues such as `Bus error (core dumped)`. For detailed steps, please refer to the official guidance [here](https://www.intel.com/content/www/us/en/support/articles/000090831/graphics.html).
+
+### 2. Set Up Conda Environment
+We recommend using Miniconda3/Anaconda3 for environment management:
+
+```bash
+# Download Miniconda
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+
+# Create environment
+conda create --name ktransformers python=3.11
+conda activate ktransformers
+
+# Install required libraries
+conda install -c conda-forge libstdcxx-ng
+
+# Verify GLIBCXX version (should include 3.4.32)
+strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
+```
+
+> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`.
+
+### 3. Install PyTorch and IPEX-LLM
+Install PyTorch with XPU backend support and [IPEX-LLM](https://github.com/intel/ipex-llm):
+
+```bash
+pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu
+pip uninstall torch torchvision torchaudio
+pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu # install torch 2.7
+pip install packaging ninja cpufeature numpy
+pip uninstall intel-opencl-rt dpcpp-cpp-rt
+```
+
+### 4. Build ktransformers
+
+```bash
+# Clone repository
+git clone https://github.com/kvcache-ai/ktransformers.git
+cd ktransformers
+git submodule update --init
+
+# Install dependencies
+bash install.sh
+pip uninstall triton pytorch-triton-xpu
+pip install pytorch-triton-xpu==3.3.0 --extra-index-url https://download.pytorch.org/whl/xpu # to avoid potential triton import error
+```
+
+## Running DeepSeek-R1 Models
+
+### Configuration for 16GB VRAM GPUs
+Use our optimized configuration for constrained VRAM:
+
+```bash
+export SYCL_CACHE_PERSISTENT=1
+export ONEAPI_DEVICE_SELECTOR=level_zero:0
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+
+python ktransformers/local_chat.py \
+  --model_path deepseek-ai/DeepSeek-R1 \
+  --gguf_path \
+  --optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \
+  --cpu_infer \
+  --device xpu \
+  --max_new_tokens 200
+```
+
+## Known Limitations
+- The serving function is not yet supported on the Intel GPU platform.
+
+## Troubleshooting
+1. Best Known Config (BKC) for best performance
+
+To obtain the best performance on the Intel GPU platform, we recommend locking the GPU frequency and setting the CPU to performance mode with the settings below.
+```bash
+echo "performance" | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
+echo 0 | sudo tee /sys/devices/system/cpu/cpu*/power/energy_perf_bias
+# 2400 is max frequency for Arc A770
+sudo xpu-smi config -d 0 -t 0 --frequencyrange 2400,2400
+# 2850 is max frequency for Arc B580
+# sudo xpu-smi config -d 0 -t 0 --frequencyrange 2850,2850
+```
+
+2. Runtime error like `xpu/sycl/TensorCompareKernels.cpp:163: xxx. Aborted (core dumped)`
+
+This error is mostly related to the GPU driver. If you encounter it, you can update `intel-level-zero-gpu` to `1.3.29735.27-914~22.04` (a version we have verified) with the commands below.
+```bash
+wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \
+sudo gpg --dearmor --output /usr/share/keyrings/intel-graphics.gpg
+echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu jammy client" | \
+sudo tee /etc/apt/sources.list.d/intel-gpu-jammy.list
+sudo apt update
+# or sudo apt update --allow-insecure-repositories
+sudo apt install intel-level-zero-gpu=1.3.29735.27-914~22.04
+```
+
+3. `ImportError: cannot import name 'intel' from 'triton._C.libtriton'`
+
+Installing `triton` breaks `pytorch-triton-xpu`. You can resolve the issue with the following commands:
+```bash
+pip uninstall triton pytorch-triton-xpu
+# Reinstall correct version of pytorch-triton-xpu
+pip install pytorch-triton-xpu==3.3.0 --index-url https://download.pytorch.org/whl/xpu
+```
\ No newline at end of file
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 928de48..173ce07 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -63,18 +63,23 @@ def local_chat(
     prompt_file : str | None = None,
     mode: str = "normal",
     force_think: bool = False,
-    chunk_size: int = 8192
+    chunk_size: int = 8192,
+    device: str = "cuda"
 ):
     torch.set_grad_enabled(False)
 
     Config().cpu_infer = cpu_infer
+    if torch.xpu.is_available():
+        use_cuda_graph = False
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
     if mode == 'long_context':
         assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
         torch.set_default_dtype(torch.float16)
+    elif torch.xpu.is_available() and config.architectures[0] == "DeepseekV3ForCausalLM":
+        torch.set_default_dtype(torch.float16)
     else:
         torch.set_default_dtype(config.torch_dtype)
 
@@ -109,7 +114,7 @@ def local_chat(
         gguf_path = input(
             "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
         )
-    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
+    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, default_device=device)
 
     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
@@ -172,12 +177,12 @@ def local_chat(
         if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
             generated = prefill_and_generate(
-                model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = 
force_think, chunk_size = chunk_size, + model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim ) else: generated = prefill_and_generate( - model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, + model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, ) diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index c2901ac..1121b4a 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -213,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module): self.v_caches = [] - def load(self, inference_context: "sched_ext.InferenceContext"): + def load(self, inference_context: "sched_ext.InferenceContext"): for i in range(self.config.num_hidden_layers): self.k_caches.append( @@ -293,7 +293,7 @@ class KGQACache(nn.Module): self.v_caches = [] - def load(self, inference_context: sched_ext.InferenceContext): + def load(self, inference_context: "sched_ext.InferenceContext"): print(self.config.num_hidden_layers) for i in range(self.config.num_hidden_layers): self.k_caches.append( diff --git a/ktransformers/models/modeling_deepseek.py b/ktransformers/models/modeling_deepseek.py index e14a521..f6845ec 100644 --- a/ktransformers/models/modeling_deepseek.py +++ b/ktransformers/models/modeling_deepseek.py @@ -107,6 +107,7 @@ class DeepseekV2RMSNorm(nn.Module): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + self.hidden_size = hidden_size def forward(self, hidden_states): input_dtype = hidden_states.dtype diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index caceb98..41dbf5a 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -587,6 +587,100 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): return attn_output, None, past_key_value + def forward_xpu( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + position_embeddings = kwargs.get("position_embeddings", None) + if position_embeddings is not None: + cos, sin = position_embeddings + key_states = torch.cat( + [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])], + dim=-1 + ) + from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced + rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim :], + key_states[:, :, :, self.qk_nope_head_dim:], + cos, sin, True) + else: + q_nope, q_pe = torch.split( + query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + cos, sin = self.rotary_emb(q_pe, position_ids) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states.half(), value_states.half(), self.layer_idx, cache_kwargs + ) + + attn_weights = None + from ipex_llm.transformers.models.common import scaled_dot_product_attention + attn_output = scaled_dot_product_attention( + query_states.half(), key_states, value_states, + attention_mask.half(), q_len == kv_seq_len, self.softmax_scale + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output).to(hidden_states.dtype) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + def forward( self, hidden_states: torch.Tensor, @@ -598,10 +692,21 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if (os.name == 'nt' - or get_compute_capability() < 8 - or hidden_states.device.type == 'cpu' - or device_manager.gpu_vendor != GPUVendor.NVIDIA): + if 
torch.xpu.is_available(): + return self.forward_xpu( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs, + ) + elif (os.name == 'nt' + or get_compute_capability() < 8 + or hidden_states.device.type == 'cpu' + or device_manager.gpu_vendor != GPUVendor.NVIDIA): return self.forward_windows( hidden_states, attention_mask, diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 279439d..d7d1926 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -51,7 +51,10 @@ def generate_cuda_graphs(chunk_size: int) -> list: return deduplicate_and_sort(base_list + multiples) #cuda_graphs = [Config().chunk_size] -cuda_graphs = generate_cuda_graphs(Config().chunk_size) +if torch.cuda.is_available(): + cuda_graphs = generate_cuda_graphs(Config().chunk_size) +else: + cuda_graphs = 1 # class Base(BaseInjectedModule, ABC): class KExpertsBase(ABC): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs): @@ -177,6 +180,11 @@ class KExpertsCPU(KExpertsBase): n_routed_experts = self.n_routed_experts self.cpu_infer = KExpertsCPU.CPU_INFER # n_routed_experts = len(self.orig_module) + model_dtype = torch.get_default_dtype() + if torch.xpu.is_available() and model_dtype == torch.float16: + hidden_type = 1 # fp16 + else: + hidden_type = 30 # bf16 if self.backend == "llamafile": moe_config = MOEConfig( n_routed_experts, @@ -192,7 +200,7 @@ class KExpertsCPU(KExpertsBase): self.gate_type, self.up_type, self.down_type, - 30, # TODO: get from model.dtype + hidden_type, # TODO: get from model.dtype ) self.moe = MOE(moe_config) elif self.backend == "AMXBF16": @@ -252,8 +260,12 @@ class KExpertsCPU(KExpertsBase): KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True) KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) - KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) - KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) + if torch.xpu.is_available(): + KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=model_dtype) + KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True) + else: + KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) + KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0): if bsz_tensor is None: @@ -285,9 +297,9 @@ class KExpertsCPU(KExpertsBase): def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0): # generate, capture and run cuda graph # print(expert_ids) - if bsz_tensor is None: + if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1): bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32) - if torch.cuda.is_current_stream_capturing(): + if torch.cuda.is_available() and 
torch.cuda.is_current_stream_capturing(): if cuda_graph_idx != -1: KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True) KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True) @@ -307,6 +319,15 @@ class KExpertsCPU(KExpertsBase): self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) return KExpertsCPU.output_gpu_map[self.out_device] + elif input_tensor.size(0)==1 and torch.xpu.is_available(): + KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True) + KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True) + KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True) + # KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True) + self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr())) + self.cpu_infer.sync() + KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) + return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1) else: input_tensor = input_tensor.contiguous().cpu() expert_ids = expert_ids.contiguous().cpu() @@ -822,7 +843,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE): topk_idx, topk_weight, aux_loss = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) @@ -922,7 +943,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # only for generate phase - if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) @@ -1122,7 +1143,7 @@ class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE): # only for generate phase - if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity, bsz_tensor).squeeze(0) @@ -1304,7 +1325,7 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock): 
routing_weights = routing_weights.to(hidden_states.dtype) # only for generate phase - if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx) y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ @@ -1417,7 +1438,7 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock): routing_weights = routing_weights.to(hidden_states.dtype) # only for generate phase - if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx) # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index a6f95ac..f5f96c1 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -182,4 +182,34 @@ class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase): if self.weight is not None: self.weight = None if self.e_score_correction_bias is not None: - self.e_score_correction_bias = None \ No newline at end of file + self.e_score_correction_bias = None + + +class KMoEGateIPEXLLM(KMoEGate): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + generate_device: str = "xpu", + prefill_device: str = "xpu", + **kwargs, + ): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) + KMoEGate.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + self.generate_device = generate_device + self.prefill_device = prefill_device + + def forward(self, hidden_states) -> torch.Tensor: + x = hidden_states.view(-1, hidden_states.size(-1)) + logits = torch.nn.functional.linear( + x.type(torch.float32), self.orig_module.weight.type(torch.float32), None + ) + scores = logits.sigmoid() + + from ipex_llm.transformers.models.common import moe_group_topk + topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias, + self.n_group, self.topk_group, self.top_k, + self.norm_topk_prob, self.routed_scaling_factor) + return topk_idx, topk_weight.to(x.dtype) \ No newline at end of file diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py index 22d580b..6d616d1 100644 --- a/ktransformers/operators/layernorm.py +++ b/ktransformers/operators/layernorm.py @@ -30,10 +30,11 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader -from flashinfer.norm import ( - fused_add_rmsnorm, - rmsnorm, -) 
+if not torch.xpu.is_available(): + from flashinfer.norm import ( + fused_add_rmsnorm, + rmsnorm, + ) logger = logging.getLogger(__name__) @@ -193,4 +194,29 @@ class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule): x = x * torch.rsqrt(variance + self.variance_epsilon) if residual is not None: return self.weight * x.to(input_dtype), residual - return self.weight * x.to(input_dtype) \ No newline at end of file + return self.weight * x.to(input_dtype) + + +class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "xpu", + generate_device: str = "xpu", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.hidden_size, + orig_module.variance_epsilon) + self.eps = orig_module.variance_epsilon + + def forward(self, x: torch.Tensor) -> torch.Tensor: + from ipex_llm.transformers.models.common import rms_norm_forward + output = rms_norm_forward(self, x.float()) + return output.to(x.dtype) + + def load(self): + BaseInjectedModule.load(self) + if self.weight.dtype != torch.float32: + self.weight = self.weight.float() \ No newline at end of file diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 2b12b15..9ce45d1 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -14,18 +14,20 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved. import ctypes import torch from torch import Tensor, nn -import KTransformersOps -import vLLMMarlin +if not torch.xpu.is_available(): + import KTransformersOps + import vLLMMarlin from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader from ktransformers.util.utils import InferenceState -from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import ( - MarlinWorkspace, - marlin_quantize, - GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MIN_THREAD_K, - GPTQ_MARLIN_MAX_PARALLEL, - vllm_marlin_quantize -) +if not torch.xpu.is_available(): + from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import ( + MarlinWorkspace, + marlin_quantize, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MIN_THREAD_K, + GPTQ_MARLIN_MAX_PARALLEL, + vllm_marlin_quantize + ) from ktransformers.operators.base_operator import BaseInjectedModule from transformers.configuration_utils import PretrainedConfig from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant @@ -778,6 +780,75 @@ class KLinearCPUInfer(KLinearBase): if self.has_bias: self.bias = None +class KLinearIPEXLLM(KLinearBase): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + device: str = "xpu", + precision: str = "sym_int4", + **kwargs, + ): + super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + self.has_bias = False + self.dtype = torch.get_default_dtype() + self.weight = None + self.has_bias = False + self.precision = precision + self.qtype = None + + def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor: + dtype = x.dtype + out_device = x.device + from ipex_llm.transformers.models.common import linear_forward + x = linear_forward(x.half(), self.weight, self.qtype, self.out_features) + + if self.has_bias: + x = x + self.bias + x = x.to(dtype=dtype, 
device=out_device) + return x + + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if self.loaded: return + if device is None: device = self.device + assert device.lower()[:3] == "xpu", "IPEX-LLM quantized linear only supports XPU device" + if w is None: w = self.load_weight(device=device) + + if isinstance(w, nn.Parameter): + try: + weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T + except: + weight = w.to(dtype=self.dtype).T + self.has_bias = False + elif isinstance(w, tuple): + try: + weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T + except: + weight = w[0].to(dtype=self.dtype).T + self.bias = w[1].to(dtype=self.dtype) + self.has_bias = True + else: + raise ValueError("Invalid weight type") + weight = weight.to("cpu").float().transpose(0, 1).contiguous() + + if self.has_bias: + self.bias = self.bias.to(device) + + # quantize linear weight + from ipex_llm.transformers.models.common import quantize_linear + paramsLowBit, qtype = quantize_linear(weight, self.in_features, self.precision) + self.weight = paramsLowBit.to(device) + self.qtype = qtype + self.loaded = True + + def unload(self): + if self.weight is not None: + self.weight = None + if self.has_bias: + self.bias = None + LINEAR_MAP = { "KLinearMarlin": KLinearMarlin, "KLinearTorch": KLinearTorch, @@ -785,6 +856,7 @@ LINEAR_MAP = { "VLinearMarlin": VLinearMarlin, "KLinearFP8": KLinearFP8, "KLinearQ8": KLinearQ8, + "KLinearIPEXLLM": KLinearIPEXLLM, } class KTransformersLinear(BaseInjectedModule, KLinearBase): diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index 4aa223d..8299d4c 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -647,6 +647,13 @@ class KDeepseekV2Model(BaseInjectedModule): if position_ids is None: position_ids = cache_position.unsqueeze(0) + if inputs_embeds.device.type == "xpu" and position_ids is not None: + cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds, + position_ids) + position_embeddings = (cos, sin) + else: + position_embeddings = None + if per_layer_prefill_flag: causal_mask = None else: @@ -737,6 +744,7 @@ class KDeepseekV2Model(BaseInjectedModule): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) t5 = time.time() if per_layer_prefill_flag: diff --git a/ktransformers/optimize/optimize.py b/ktransformers/optimize/optimize.py index 72c8407..bbe08c8 100644 --- a/ktransformers/optimize/optimize.py +++ b/ktransformers/optimize/optimize.py @@ -103,7 +103,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + "." 
- gen_optimize_config(child, out_data, rule_list, child_prefix) + gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device) def translate_model_config(model_config: PretrainedConfig): @@ -127,8 +127,11 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo with torch.device("meta"): inject(module, optimize_config, model_config, weights_loader) # pre load lm_head because its big inter result - load_weights(module.lm_head, weights_loader, "lm_head.") - load_weights(module, weights_loader) + load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device) + load_weights(module, weights_loader, device=default_device) module.gguf_loader = weights_loader del_meta(module) - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.xpu.is_available(): + torch.xpu.empty_cache() diff --git a/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml new file mode 100644 index 0000000..5de582f --- /dev/null +++ b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml @@ -0,0 +1,64 @@ +- match: + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^model\\.layers\\..*" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "xpu" + prefill_device: "xpu" + generate_op: "KLinearIPEXLLM" + prefill_op: "KLinearIPEXLLM" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "xpu" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "xpu" + recursive: False # don't recursively inject submodules of this module +- match: + class: ktransformers.models.modeling_deepseek.DeepseekV2RMSNorm + replace: + class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + device: "xpu" +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml new file mode 100644 index 0000000..c0e46c3 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml @@ -0,0 +1,81 @@ +- match: + class: 
ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "xpu" + prefill_device: "xpu" + generate_op: "KLinearIPEXLLM" + prefill_op: "KLinearIPEXLLM" +- match: + name: "^model\\.layers\\..*" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "xpu" + prefill_device: "xpu" + generate_op: "KLinearIPEXLLM" + prefill_op: "KLinearIPEXLLM" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm + replace: + class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM + kwargs: + generate_device: "xpu" + prefill_device: "xpu" +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGateIPEXLLM + kwargs: + generate_device: "xpu:0" + prefill_device: "xpu:0" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "xpu" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "xpu" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "xpu" + prefill_device: "xpu" + absorb_for_prefill: False # change this to True to enable long context(prefill may slower). 
+- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index 4518366..5e4ffd6 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -24,7 +24,8 @@ from typing import Sequence import os from enum import IntEnum import torch -import KTransformersOps +if not torch.xpu.is_available(): + import KTransformersOps import ctypes import math diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py index 93b94c4..5adaaeb 100644 --- a/ktransformers/util/custom_loader.py +++ b/ktransformers/util/custom_loader.py @@ -7,7 +7,8 @@ from typing import Sequence import os from enum import IntEnum import torch -import KTransformersOps +if not torch.xpu.is_available(): + import KTransformersOps from safetensors import safe_open from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant from ktransformers.util.custom_gguf import * @@ -459,7 +460,7 @@ class GGUFLoader(ModelLoader): values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) else: values = GGML_DEQUANTIZE[ggml_name](data) - values = torch.from_numpy(values) + values = torch.from_numpy(values).to(device) if ggml_name == "BF16": values = values.view(torch.bfloat16) diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 3ddf639..bdf19fe 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -27,7 +27,8 @@ from ktransformers.operators import base_operator from ktransformers.models.custom_cache import StaticCache from ktransformers.util.cuda_graph_runner import CUDAGraphRunner from ktransformers.util.textstream import TextStreamer -from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton +if not torch.xpu.is_available(): + from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton import socket warm_uped = False @@ -59,6 +60,8 @@ def get_compute_capability(device:torch.device = None): return min_compute_capability_major else: return torch.cuda.get_device_properties(device) + else: + return 0 def set_module(model, submodule_key, module): tokens = submodule_key.split('.') @@ -97,7 +100,7 @@ def get_all_used_cuda_device(device_map:dict): all_device_list = list(all_device_list) return all_device_list -def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = ""): +def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"): prefix = prefix.replace("orig_module.", "") persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set} local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items()) @@ -118,7 +121,10 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str target_dtype = torch.get_default_dtype() device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) print(f"loading {translated_key} to {device}") - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.xpu.is_available(): + torch.xpu.empty_cache() weights = load_dequantized_tensor(translated_key, 
device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
@@ -126,12 +132,24 @@
             #print(load_config.tensor_file_map.keys())
             raise Exception(f"can't find {translated_key} in GGUF file!")
 
-def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix=''):
+
+def sync_all_device(all_device_list):
+    for device in all_device_list:
+        if "cuda" in device.lower():
+            torch.cuda.synchronize(device)
+        elif "xpu" in device.lower():
+            torch.xpu.synchronize(device)
+        else:
+            raise RuntimeError("The device {} is not available".format(device))
+
+torch_device_mapping = {"cuda": "cuda:0", "xpu": "xpu:0"}
+
+def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
     #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
-        load_cur_state_dict(module, gguf_loader, prefix)
+        load_cur_state_dict(module, gguf_loader, prefix, device=device)
         for name, child in module._modules.items():
-            load_weights(child, gguf_loader, prefix+name+".")
+            load_weights(child, gguf_loader, prefix+name+".", device=device)
     else:
         module.load()
@@ -194,8 +212,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     torch._dynamo.config.suppress_errors = True
     batch_size, seq_length = inputs.shape
     device_map = model.gguf_loader.tensor_device_map
-    torch_device = get_device('blk.0.self_attn', device_map)
-    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+    torch_device = get_device('model.layers.0.self_attn', device_map)
+    torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
     inputs = inputs.to(torch_device)
     all_cuda_device = get_all_used_cuda_device(device_map)
 
@@ -208,7 +226,12 @@
             logits = cuda_graph_runner(cur_token, position_ids, cache_position)
         else:
             # custom_stream = torch.cuda.Stream()
-            torch.cuda.set_device(torch_device)
+            if torch.cuda.is_available():
+                torch.cuda.set_device(torch_device)
+            elif torch.xpu.is_available():
+                torch.xpu.set_device(torch_device)
+            else:
+                raise RuntimeError(f"The device: {torch_device} is not available")
             inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
             # with torch.cuda.stream(custom_stream):
             logits=model(inputs_embeds=inputs_embeds,
@@ -216,10 +239,9 @@
                         cache_position=cache_position,
                         past_key_values=past_key_values,
                         return_dict=False, use_cache=True)[0]
-            if past_key_values != None:
+            if past_key_values != None and isinstance(past_key_values, StaticCache):
                 past_key_values.change_seq_length(1)
-            for device in all_cuda_device:
-                torch.cuda.synchronize(device)
+            sync_all_device(all_cuda_device)
             #print(logits)
             next_token_scores = logits_warper(inputs, logits[:, -1, :])
             if generation_config.do_sample:
@@ -245,11 +267,19 @@
 
         return logits
 
-    torch.cuda.set_device(torch_device)
+    if torch.cuda.is_available():
+        torch.cuda.set_device(torch_device)
+    elif torch.xpu.is_available():
+        torch.xpu.set_device(torch_device)
+    else:
+        raise RuntimeError(f"The device: {torch_device} is not available")
     with torch.no_grad():
 
         stream = TextStreamer(tokenizer)
-        if mode != 'long_context':
+        if torch.xpu.is_available():
+            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
+
past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None) + elif mode != 'long_context': past_key_values = StaticCache( config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype ) diff --git a/setup.py b/setup.py index c5bf128..b8f318d 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,8 @@ try: from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME except ImportError: MUSA_HOME=None - +KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available() + with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1" class CpuInstructInfo: @@ -225,6 +226,8 @@ class VersionInfo: backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}" elif ROCM_HOME is not None: backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}" + elif torch.xpu.is_available(): + backend_version = f"xpu" else: raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.") package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}" @@ -495,6 +498,8 @@ class CMakeBuild(BuildExtension): cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"] elif ROCM_HOME is not None: cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"] + elif KTRANSFORMERS_BUILD_XPU: + cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"] else: raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.") @@ -620,29 +625,36 @@ elif MUSA_HOME is not None: ] } ) +elif torch.xpu.is_available(): #XPUExtension is not available now. + ops_module = None else: raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.") -ext_modules = [ - CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), - ops_module, - CUDAExtension( - 'vLLMMarlin', [ - 'csrc/custom_marlin/binding.cpp', - 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu', - 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu', - ], - extra_compile_args={ - 'cxx': ['-O3'], - 'nvcc': ['-O3', '-Xcompiler', '-fPIC'], - }, - ) -] -if with_balance: - print("using balance_serve") - ext_modules.append( - CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve")) - ) +if not torch.xpu.is_available(): + ext_modules = [ + CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), + ops_module, + CUDAExtension( + 'vLLMMarlin', [ + 'csrc/custom_marlin/binding.cpp', + 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu', + 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu', + ], + extra_compile_args={ + 'cxx': ['-O3'], + 'nvcc': ['-O3', '-Xcompiler', '-fPIC'], + }, + ) + ] + if with_balance: + print("using balance_serve") + ext_modules.append( + CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve")) + ) +else: + ext_modules = [ + CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), + ] setup( name=VersionInfo.PACKAGE_NAME,