mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-10 23:34:35 +00:00
add XPU support for qwen3moe local chat
This commit is contained in:
parent 25893366b6
commit adc0906967
9 changed files with 223 additions and 25 deletions
@@ -13,6 +13,7 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
 from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
+from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_loader import GGUFLoader
@@ -870,3 +871,75 @@ class KLlamaAttention(BaseInjectedModule):
         attn_weights = None
 
         return attn_output, attn_weights, past_key_value
+
+
+class KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention):
+    def __init__(self,
+                 key: str,
+                 gguf_loader: GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 prefill_device: str = "xpu",
+                 generate_device: str = "xpu",
+                 chunck_size: int = 1000,
+                 **kwargs):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+        self.orig_module.__init__(orig_module.config,
+                                  orig_module.layer_idx)
+        self.chunck_size = chunck_size  # TODO: generate chunck_size automatically.
+        assert prefill_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
+        assert generate_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.Tensor],
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        bsz, q_len, _ = hidden_states.size()
+        input_dtype = hidden_states.dtype
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        if not hasattr(self, 'qkv_proj'):
+            from ipex_llm.transformers.models.common import merge_quantized_qkv
+            merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module)
+
+        qkv = self.qkv_proj(hidden_states)
+        qkv = qkv.view(bsz, q_len, -1, self.head_dim)
+        qkv = qkv.transpose(1, 2)
+        query_states, key_states, value_states = qkv.split([self.config.num_attention_heads,
+                                                            self.config.num_key_value_heads,
+                                                            self.config.num_key_value_heads], dim=1)
+        query_states = self.q_norm(query_states)
+        key_states = self.k_norm(key_states)
+
+        if position_embeddings is None:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        cos, sin = position_embeddings
+
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states.half(), value_states.half(),
+                                                             self.layer_idx, cache_kwargs)
+
+        attn_weights = None
+        from ipex_llm.transformers.models.common import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states.half(), key_states, value_states,
+            attention_mask.half(), q_len == key_states.size(2), self.scaling
+        )
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output).to(input_dtype)
+        return attn_output, attn_weights
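For reference, the fused IPEX-LLM helpers used above (merge_quantized_qkv, rotary_half_with_cache_inplaced and the XPU scaled_dot_product_attention) compute the same math as the standard Qwen3 attention path. The following is a minimal, unfused PyTorch sketch of that forward pass: split a merged QKV projection by head counts, apply per-head normalization to queries and keys, rotate with RoPE, then run grouped-query attention. It is illustrative only; the function name and the num_heads/num_kv_heads/head_dim/scaling parameters are assumptions for the sketch, not attributes read from the module above.

# A minimal, unfused sketch of the attention computed above
# (illustrative names and arguments; not the IPEX-LLM API).
import torch
import torch.nn.functional as F

def qwen3_attention_reference(qkv, q_norm, k_norm, cos, sin,
                              num_heads, num_kv_heads, head_dim, scaling):
    bsz, q_len, _ = qkv.shape
    # Split the merged QKV projection into per-head query/key/value tensors.
    qkv = qkv.view(bsz, q_len, -1, head_dim).transpose(1, 2)
    q, k, v = qkv.split([num_heads, num_kv_heads, num_kv_heads], dim=1)

    # Qwen3 normalizes queries and keys per head before applying RoPE.
    q, k = q_norm(q), k_norm(k)

    # Rotary position embedding, "rotate half" formulation.
    def rope(x):
        x1, x2 = x.chunk(2, dim=-1)
        return x * cos.unsqueeze(1) + torch.cat((-x2, x1), dim=-1) * sin.unsqueeze(1)
    q, k = rope(q), rope(k)

    # Grouped-query attention: repeat K/V heads to match the query heads,
    # then run causal SDPA (the prefill case, i.e. q_len == kv_len above).
    rep = num_heads // num_kv_heads
    k = k.repeat_interleave(rep, dim=1)
    v = v.repeat_interleave(rep, dim=1)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scaling)
    return out.transpose(1, 2).reshape(bsz, q_len, -1)

The production path additionally updates the KV cache in fp16 and applies the output projection o_proj, which this sketch omits.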
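The lazy `if not hasattr(self, 'qkv_proj')` branch fuses the three projection matrices on the first forward call so that a single matmul produces Q, K and V at once. As a rough illustration of that idea for ordinary float nn.Linear layers (merge_quantized_qkv performs the analogous fusion on IPEX-LLM's quantized linears and attaches the result to the module), the fusion amounts to concatenating the weights; the helper below and its name are hypothetical:

# Hedged sketch: fusing q/k/v into one projection for plain float nn.Linear
# layers. merge_quantized_qkv does the analogous fusion for IPEX-LLM
# quantized linears; this helper is illustrative only.
import torch
from torch import nn

def merge_qkv_float(q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear) -> nn.Linear:
    out_features = q_proj.out_features + k_proj.out_features + v_proj.out_features
    has_bias = q_proj.bias is not None
    qkv_proj = nn.Linear(q_proj.in_features, out_features, bias=has_bias)
    with torch.no_grad():
        qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
        if has_bias:
            qkv_proj.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))
    # One matmul now yields [Q | K | V] concatenated along the last dimension.
    return qkv_proj

Running one larger matmul instead of three smaller ones generally reduces per-layer kernel-launch overhead, which is the usual motivation for a fused QKV path.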