add XPU support for qwen3moe local chat

This commit is contained in:
rnwang04 2025-05-21 18:33:41 +08:00
parent 25893366b6
commit adc0906967
9 changed files with 223 additions and 25 deletions

View file

@ -13,6 +13,7 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
@ -870,3 +871,75 @@ class KLlamaAttention(BaseInjectedModule):
attn_weights = None
return attn_output, attn_weights, past_key_value
class KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "xpu",
generate_device: str = "xpu",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
assert prefill_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
assert generate_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.Tensor],
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
bsz, q_len, _ = hidden_states.size()
input_dtype = hidden_states.dtype
hidden_shape = (*input_shape, -1, self.head_dim)
if not hasattr(self, 'qkv_proj'):
from ipex_llm.transformers.models.common import merge_quantized_qkv
merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module)
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, -1, self.head_dim)
qkv = qkv.transpose(1, 2)
query_states, key_states, value_states = qkv.split([self.config.num_attention_heads,
self.config.num_key_value_heads,
self.config.num_key_value_heads], dim=1)
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
if position_embeddings is None:
position_embeddings = self.rotary_emb(hidden_states, position_ids)
cos, sin = position_embeddings
from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states.half(), value_states.half(),
self.layer_idx, cache_kwargs)
attn_weights = None
from ipex_llm.transformers.models.common import scaled_dot_product_attention
attn_output = scaled_dot_product_attention(
query_states.half(), key_states, value_states,
attention_mask.half(), q_len == key_states.size(2), self.scaling
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output).to(input_dtype)
return attn_output, attn_weights