add XPU support for qwen3moe local chat

This commit is contained in:
rnwang04 2025-05-21 18:33:41 +08:00
parent 25893366b6
commit adc0906967
9 changed files with 223 additions and 25 deletions

View file

@ -47,7 +47,7 @@ strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
Install PyTorch with XPU backend support and [IPEX-LLM](https://github.com/intel/ipex-llm):
```bash
pip install ipex-llm[xpu_2.6]==2.3.0rc1 --extra-index-url https://download.pytorch.org/whl/xpu
pip install ipex-llm[xpu_2.6]==2.3.0b20250518 --extra-index-url https://download.pytorch.org/whl/xpu
pip uninstall torch torchvision torchaudio
pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu # install torch2.7
pip uninstall intel-opencl-rt dpcpp-cpp-rt

View file

@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
@ -78,7 +78,7 @@ def local_chat(
if mode == 'long_context':
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
elif torch.xpu.is_available() and config.architectures[0] == "DeepseekV3ForCausalLM":
elif xpu_fp16_model(config):
torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
@ -94,11 +94,16 @@ def local_chat(
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
if torch.xpu.is_available():
config._attn_implementation = "eager"
model = custom_models[config.architectures[0]](config)
else:
if torch.xpu.is_available():
attn_implementation = "eager"
else:
attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation="flash_attention_2"
config, trust_remote_code=True, attn_implementation=attn_implementation
)
if optimize_config_path is None:

View file

@ -13,6 +13,7 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
@ -870,3 +871,75 @@ class KLlamaAttention(BaseInjectedModule):
attn_weights = None
return attn_output, attn_weights, past_key_value
class KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "xpu",
generate_device: str = "xpu",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
assert prefill_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
assert generate_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.Tensor],
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
bsz, q_len, _ = hidden_states.size()
input_dtype = hidden_states.dtype
hidden_shape = (*input_shape, -1, self.head_dim)
if not hasattr(self, 'qkv_proj'):
from ipex_llm.transformers.models.common import merge_quantized_qkv
merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module)
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, -1, self.head_dim)
qkv = qkv.transpose(1, 2)
query_states, key_states, value_states = qkv.split([self.config.num_attention_heads,
self.config.num_key_value_heads,
self.config.num_key_value_heads], dim=1)
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
if position_embeddings is None:
position_embeddings = self.rotary_emb(hidden_states, position_ids)
cos, sin = position_embeddings
from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states.half(), value_states.half(),
self.layer_idx, cache_kwargs)
attn_weights = None
from ipex_llm.transformers.models.common import scaled_dot_product_attention
attn_output = scaled_dot_product_attention(
query_states.half(), key_states, value_states,
attention_mask.half(), q_len == key_states.size(2), self.scaling
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output).to(input_dtype)
return attn_output, attn_weights

View file

@ -1421,19 +1421,28 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
return final_out
class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
router_logits = self.gate(hidden_states, bsz_tensor)
if bsz_tensor is None:
router_logits = self.gate(hidden_states)
else:
router_logits = self.gate(hidden_states, bsz_tensor)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
if router_logits.device.type == "xpu":
from ipex_llm.transformers.models.common import moe_softmax_topk
selected_experts, routing_weights = moe_softmax_topk(
router_logits.half(), self.top_k, self.norm_topk_prob
)
else:
routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

View file

@ -207,16 +207,19 @@ class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
generate_device: str = "xpu",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.hidden_size,
self.orig_module.__init__(orig_module.weight.shape[0],
orig_module.variance_epsilon)
self.eps = orig_module.variance_epsilon
def forward(self, x: torch.Tensor) -> torch.Tensor:
from ipex_llm.transformers.models.common import rms_norm_forward
output = rms_norm_forward(self, x.float())
if x.dtype not in [torch.float32, torch.float16]:
output = rms_norm_forward(self, x.float())
else:
output = rms_norm_forward(self, x)
return output.to(x.dtype)
def load(self):
BaseInjectedModule.load(self)
if self.weight.dtype != torch.float32:
if self.weight.dtype not in [torch.float32, torch.float16]:
self.weight = self.weight.float()

View file

@ -306,6 +306,12 @@ class KQwen2MoeModel(BaseInjectedModule):
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
if torch.xpu.is_available() and inputs_embeds.device.type == "xpu":
position_embeddings = self.rotary_emb(hidden_states, position_ids)
else:
position_embeddings = None
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
@ -369,6 +375,7 @@ class KQwen2MoeModel(BaseInjectedModule):
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
if per_layer_prefill_flag:
# print(f"to cpu")
@ -376,8 +383,10 @@ class KQwen2MoeModel(BaseInjectedModule):
torch.cuda.empty_cache()
hidden_states = layer_outputs[0]
if use_cache:
if use_cache and len(layer_outputs) > 1:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
else:
next_decoder_cache = None
if output_attentions:
all_self_attns += (layer_outputs[1],)
@ -396,11 +405,14 @@ class KQwen2MoeModel(BaseInjectedModule):
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if next_decoder_cache is not None:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
else:
next_cache = past_key_values
if not return_dict:
return tuple(

View file

@ -0,0 +1,80 @@
- match:
name: "rotary_emb$"
replace:
class: ktransformers.operators.RoPE.KQwen3MoeRotaryEmbedding
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
generate_op: "KLinearIPEXLLM"
prefill_op: "KLinearIPEXLLM"
- match:
name: "^model\\.layers\\.(?!.*mlp\\.gate).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
generate_op: "KLinearIPEXLLM"
prefill_op: "KLinearIPEXLLM"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
replace:
class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2 # mlp module with custom forward function
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "xpu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "xpu"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KQwen3MoeAttentionIPEXLLM
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KQwen2MoeModel"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm
replace:
class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
kwargs:
generate_device: "xpu"
prefill_device: "xpu"
- match:
class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP
replace:
class: ktransformers.operators.mlp.KQwen2MoeMLP
kwargs:
generate_device: "xpu"
prefill_device: "xpu"

View file

@ -459,9 +459,10 @@ class GGUFLoader(ModelLoader):
if "cuda" in device.lower():
values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
else:
values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values).to(device)
np_values = np.copy(GGML_DEQUANTIZE[ggml_name](data))
values = torch.from_numpy(np_values).to(device)
del np_values
if ggml_name == "BF16":
values = values.view(torch.bfloat16)

View file

@ -144,6 +144,18 @@ def sync_all_device(all_device_list):
torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}
def xpu_fp16_model(config):
# This function is to check if we run this model on XPU with FP16 dtype
if not torch.xpu.is_available():
return False
if config.architectures[0] == "DeepseekV3ForCausalLM":
return True
if config.architectures[0] == "Qwen3MoeForCausalLM" and config.hidden_size == 4096:
# Qwen3-30B seems have precision issue with FP16
# so we only use FP16 for Qwen3-235B now
return True
return False
def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
#print(f"recursively loading weights {prefix}")
if not isinstance(module, base_operator.BaseInjectedModule):
@ -277,8 +289,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
stream = TextStreamer(tokenizer)
if torch.xpu.is_available():
from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache
if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
else:
past_key_values = DynamicNormalCache.from_legacy_cache(None)
elif mode != 'long_context':
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype