mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-11 07:44:35 +00:00
add XPU support for qwen3moe local chat
This commit is contained in:
parent
25893366b6
commit
adc0906967
9 changed files with 223 additions and 25 deletions
|
@ -13,6 +13,7 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
|
|||
from ktransformers.models.configuration_llama import LlamaConfig
|
||||
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
|
||||
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
|
||||
from typing import Optional, Tuple
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.util.custom_loader import GGUFLoader
|
||||
|
@ -870,3 +871,75 @@ class KLlamaAttention(BaseInjectedModule):
|
|||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "xpu",
|
||||
generate_device: str = "xpu",
|
||||
chunck_size: int = 1000,
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config,
|
||||
orig_module.layer_idx)
|
||||
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
|
||||
assert prefill_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
|
||||
assert generate_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor],
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
if not hasattr(self, 'qkv_proj'):
|
||||
from ipex_llm.transformers.models.common import merge_quantized_qkv
|
||||
merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module)
|
||||
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
qkv = qkv.view(bsz, q_len, -1, self.head_dim)
|
||||
qkv = qkv.transpose(1, 2)
|
||||
query_states, key_states, value_states = qkv.split([self.config.num_attention_heads,
|
||||
self.config.num_key_value_heads,
|
||||
self.config.num_key_value_heads], dim=1)
|
||||
query_states = self.q_norm(query_states)
|
||||
key_states = self.k_norm(key_states)
|
||||
|
||||
if position_embeddings is None:
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
|
||||
from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
|
||||
rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states.half(), value_states.half(),
|
||||
self.layer_idx, cache_kwargs)
|
||||
|
||||
attn_weights = None
|
||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
|
||||
attn_output = scaled_dot_product_attention(
|
||||
query_states.half(), key_states, value_states,
|
||||
attention_mask.half(), q_len == key_states.size(2), self.scaling
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output).to(input_dtype)
|
||||
return attn_output, attn_weights
|
||||
|
|
|
@ -1421,19 +1421,28 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
|
|||
return final_out
|
||||
|
||||
class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
|
||||
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
|
||||
def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):
|
||||
|
||||
orig_shape = hidden_states.shape
|
||||
sequence_length = orig_shape[1]
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
router_logits = self.gate(hidden_states, bsz_tensor)
|
||||
if bsz_tensor is None:
|
||||
router_logits = self.gate(hidden_states)
|
||||
else:
|
||||
router_logits = self.gate(hidden_states, bsz_tensor)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
if self.norm_topk_prob:
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
if router_logits.device.type == "xpu":
|
||||
from ipex_llm.transformers.models.common import moe_softmax_topk
|
||||
selected_experts, routing_weights = moe_softmax_topk(
|
||||
router_logits.half(), self.top_k, self.norm_topk_prob
|
||||
)
|
||||
else:
|
||||
routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
if self.norm_topk_prob:
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
|
|
|
@ -207,16 +207,19 @@ class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
|
|||
generate_device: str = "xpu",
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.hidden_size,
|
||||
self.orig_module.__init__(orig_module.weight.shape[0],
|
||||
orig_module.variance_epsilon)
|
||||
self.eps = orig_module.variance_epsilon
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
from ipex_llm.transformers.models.common import rms_norm_forward
|
||||
output = rms_norm_forward(self, x.float())
|
||||
if x.dtype not in [torch.float32, torch.float16]:
|
||||
output = rms_norm_forward(self, x.float())
|
||||
else:
|
||||
output = rms_norm_forward(self, x)
|
||||
return output.to(x.dtype)
|
||||
|
||||
def load(self):
|
||||
BaseInjectedModule.load(self)
|
||||
if self.weight.dtype != torch.float32:
|
||||
if self.weight.dtype not in [torch.float32, torch.float16]:
|
||||
self.weight = self.weight.float()
|
|
@ -306,6 +306,12 @@ class KQwen2MoeModel(BaseInjectedModule):
|
|||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
if torch.xpu.is_available() and inputs_embeds.device.type == "xpu":
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
else:
|
||||
position_embeddings = None
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
@ -369,6 +375,7 @@ class KQwen2MoeModel(BaseInjectedModule):
|
|||
output_router_logits=output_router_logits,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
if per_layer_prefill_flag:
|
||||
# print(f"to cpu")
|
||||
|
@ -376,8 +383,10 @@ class KQwen2MoeModel(BaseInjectedModule):
|
|||
torch.cuda.empty_cache()
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
if use_cache and len(layer_outputs) > 1:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
else:
|
||||
next_decoder_cache = None
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
@ -396,11 +405,14 @@ class KQwen2MoeModel(BaseInjectedModule):
|
|||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache()
|
||||
if use_legacy_cache
|
||||
else next_decoder_cache
|
||||
)
|
||||
if next_decoder_cache is not None:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache()
|
||||
if use_legacy_cache
|
||||
else next_decoder_cache
|
||||
)
|
||||
else:
|
||||
next_cache = past_key_values
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue