diff --git a/doc/en/xpu.md b/doc/en/xpu.md
index ffbd030..e8b3b90 100644
--- a/doc/en/xpu.md
+++ b/doc/en/xpu.md
@@ -47,7 +47,7 @@ strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
 Install PyTorch with XPU backend support and [IPEX-LLM](https://github.com/intel/ipex-llm):

 ```bash
-pip install ipex-llm[xpu_2.6]==2.3.0rc1 --extra-index-url https://download.pytorch.org/whl/xpu
+pip install ipex-llm[xpu_2.6]==2.3.0b20250518 --extra-index-url https://download.pytorch.org/whl/xpu
 pip uninstall torch torchvision torchaudio
 pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu # install torch2.7
 pip uninstall intel-opencl-rt dpcpp-cpp-rt
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 173ce07..cdc934c 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate, get_compute_capability
+from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
@@ -78,7 +78,7 @@ def local_chat(
     if mode == 'long_context':
         assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
         torch.set_default_dtype(torch.float16)
-    elif torch.xpu.is_available() and config.architectures[0] == "DeepseekV3ForCausalLM":
+    elif xpu_fp16_model(config):
        torch.set_default_dtype(torch.float16)
     else:
         torch.set_default_dtype(config.torch_dtype)
@@ -94,11 +94,16 @@ def local_chat(
                 config._attn_implementation = "eager"
             if "Mixtral" in config.architectures[0]:
                 config._attn_implementation = "flash_attention_2"
-
+            if torch.xpu.is_available():
+                config._attn_implementation = "eager"
             model = custom_models[config.architectures[0]](config)
         else:
+            if torch.xpu.is_available():
+                attn_implementation = "eager"
+            else:
+                attn_implementation = "flash_attention_2"
             model = AutoModelForCausalLM.from_config(
-                config, trust_remote_code=True, attn_implementation="flash_attention_2"
+                config, trust_remote_code=True, attn_implementation=attn_implementation
             )

     if optimize_config_path is None:
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 41dbf5a..9dfdbdc 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -13,6 +13,7 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
 from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
+from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_loader import GGUFLoader
@@ -870,3 +871,75 @@ class KLlamaAttention(BaseInjectedModule):
         attn_weights = None

         return attn_output, attn_weights, past_key_value
+
+
+class KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention):
+    def __init__(self,
+                 key: str,
+                 gguf_loader : GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 prefill_device: str = "xpu",
+                 generate_device: str = "xpu",
+                 chunck_size: int = 1000,
+                 **kwargs):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+        self.orig_module.__init__(orig_module.config,
+                                  orig_module.layer_idx)
+        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
+        assert prefill_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
+        assert generate_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.Tensor],
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        bsz, q_len, _ = hidden_states.size()
+        input_dtype = hidden_states.dtype
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        if not hasattr(self, 'qkv_proj'):
+            from ipex_llm.transformers.models.common import merge_quantized_qkv
+            merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module)
+
+        qkv = self.qkv_proj(hidden_states)
+        qkv = qkv.view(bsz, q_len, -1, self.head_dim)
+        qkv = qkv.transpose(1, 2)
+        query_states, key_states, value_states = qkv.split([self.config.num_attention_heads,
+                                                            self.config.num_key_value_heads,
+                                                            self.config.num_key_value_heads], dim=1)
+        query_states = self.q_norm(query_states)
+        key_states = self.k_norm(key_states)
+
+        if position_embeddings is None:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        cos, sin = position_embeddings
+
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states.half(), value_states.half(),
+                                                             self.layer_idx, cache_kwargs)
+
+        attn_weights = None
+        from ipex_llm.transformers.models.common import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states.half(), key_states, value_states,
+            attention_mask.half(), q_len == key_states.size(2), self.scaling
+        )
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output).to(input_dtype)
+        return attn_output, attn_weights
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 9b13f0a..33e6a1d 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -1421,19 +1421,28 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         return final_out

 class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
-    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
+    def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):
         orig_shape = hidden_states.shape
         sequence_length = orig_shape[1]
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        router_logits = self.gate(hidden_states, bsz_tensor)
+        if bsz_tensor is None:
+            router_logits = self.gate(hidden_states)
+        else:
+            router_logits = self.gate(hidden_states, bsz_tensor)

-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        if self.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        if router_logits.device.type == "xpu":
+            from ipex_llm.transformers.models.common import moe_softmax_topk
+            selected_experts, routing_weights = moe_softmax_topk(
+                router_logits.half(), self.top_k, self.norm_topk_prob
+            )
+        else:
+            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
+            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+            if self.norm_topk_prob:
+                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py
index 6d616d1..796592c 100644
--- a/ktransformers/operators/layernorm.py
+++ b/ktransformers/operators/layernorm.py
@@ -207,16 +207,19 @@ class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
                  generate_device: str = "xpu",
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-        self.orig_module.__init__(orig_module.hidden_size,
+        self.orig_module.__init__(orig_module.weight.shape[0],
                                   orig_module.variance_epsilon)
         self.eps = orig_module.variance_epsilon

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         from ipex_llm.transformers.models.common import rms_norm_forward
-        output = rms_norm_forward(self, x.float())
+        if x.dtype not in [torch.float32, torch.float16]:
+            output = rms_norm_forward(self, x.float())
+        else:
+            output = rms_norm_forward(self, x)
         return output.to(x.dtype)

     def load(self):
         BaseInjectedModule.load(self)
-        if self.weight.dtype != torch.float32:
+        if self.weight.dtype not in [torch.float32, torch.float16]:
             self.weight = self.weight.float()
\ No newline at end of file
diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py
index 8299d4c..e136b57 100644
--- a/ktransformers/operators/models.py
+++ b/ktransformers/operators/models.py
@@ -306,6 +306,12 @@ class KQwen2MoeModel(BaseInjectedModule):

         hidden_states = inputs_embeds

+        # create position embeddings to be shared across the decoder layers
+        if torch.xpu.is_available() and inputs_embeds.device.type == "xpu":
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        else:
+            position_embeddings = None
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -369,6 +375,7 @@ class KQwen2MoeModel(BaseInjectedModule):
                     output_router_logits=output_router_logits,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )
             if per_layer_prefill_flag:
                 # print(f"to cpu")
@@ -376,8 +383,10 @@ class KQwen2MoeModel(BaseInjectedModule):
                 torch.cuda.empty_cache()
             hidden_states = layer_outputs[0]

-            if use_cache:
+            if use_cache and len(layer_outputs) > 1:
                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+            else:
+                next_decoder_cache = None

             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
@@ -396,11 +405,14 @@ class KQwen2MoeModel(BaseInjectedModule):

         next_cache = None
         if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if use_legacy_cache
-                else next_decoder_cache
-            )
+            if next_decoder_cache is not None:
+                next_cache = (
+                    next_decoder_cache.to_legacy_cache()
+                    if use_legacy_cache
+                    else next_decoder_cache
+                )
+            else:
+                next_cache = past_key_values

         if not return_dict:
             return tuple(
diff --git a/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml b/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml
new file mode 100644
index 0000000..6bb4dae
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml
@@ -0,0 +1,80 @@
+- match:
+    name: "rotary_emb$"
+  replace:
+    class: ktransformers.operators.RoPE.KQwen3MoeRotaryEmbedding
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
+- match:
+    name: "^lm_head$" # regular expression
+    class: torch.nn.Linear # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
+      generate_op: "KLinearIPEXLLM"
+      prefill_op: "KLinearIPEXLLM"
+- match:
+    name: "^model\\.layers\\.(?!.*mlp\\.gate).*$" # regular expression
+    class: torch.nn.Linear # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
+      generate_op: "KLinearIPEXLLM"
+      prefill_op: "KLinearIPEXLLM"
+- match:
+    name: "^model\\.layers\\..*\\.mlp$"
+    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
+  replace:
+    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2 # mlp module with custom forward function
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
+- match:
+    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert parallelism
+    kwargs:
+      prefill_device: "xpu"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "xpu"
+  recursive: False # don't recursively inject submodules of this module
+- match:
+    name: "^model\\.layers\\..*\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KQwen3MoeAttentionIPEXLLM
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KQwen2MoeModel"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
+- match:
+    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm
+  replace:
+    class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
+- match:
+    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP
+  replace:
+    class: ktransformers.operators.mlp.KQwen2MoeMLP
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py
index 5adaaeb..edb92de 100644
--- a/ktransformers/util/custom_loader.py
+++ b/ktransformers/util/custom_loader.py
@@ -459,9 +459,10 @@ class GGUFLoader(ModelLoader):
         if "cuda" in device.lower():
             values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
         else:
-            values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values).to(device)
-
+            np_values = np.copy(GGML_DEQUANTIZE[ggml_name](data))
+            values = torch.from_numpy(np_values).to(device)
+            del np_values
+
         if ggml_name == "BF16":
             values = values.view(torch.bfloat16)

diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index bdf19fe..7301572 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -144,6 +144,18 @@ def sync_all_device(all_device_list):

 torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}

+def xpu_fp16_model(config):
+    # Check whether this model should run on XPU with FP16 dtype
+    if not torch.xpu.is_available():
+        return False
+    if config.architectures[0] == "DeepseekV3ForCausalLM":
+        return True
+    if config.architectures[0] == "Qwen3MoeForCausalLM" and config.hidden_size == 4096:
+        # Qwen3-30B seems to have precision issues with FP16,
+        # so we only use FP16 for Qwen3-235B for now
+        return True
+    return False
+
 def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
     #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
@@ -277,8 +289,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud

         stream = TextStreamer(tokenizer)
         if torch.xpu.is_available():
-            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
-            past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache
+            if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
+                past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+            else:
+                past_key_values = DynamicNormalCache.from_legacy_cache(None)
         elif mode != 'long_context':
             past_key_values = StaticCache(
                 config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype