mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-04 11:40:13 +00:00

add XPU support for qwen3moe local chat

This commit is contained in:
parent 25893366b6
commit adc0906967

9 changed files with 223 additions and 25 deletions
@@ -47,7 +47,7 @@ strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
 Install PyTorch with XPU backend support and [IPEX-LLM](https://github.com/intel/ipex-llm):
 
 ```bash
-pip install ipex-llm[xpu_2.6]==2.3.0rc1 --extra-index-url https://download.pytorch.org/whl/xpu
+pip install ipex-llm[xpu_2.6]==2.3.0b20250518 --extra-index-url https://download.pytorch.org/whl/xpu
 pip uninstall torch torchvision torchaudio
 pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu # install torch2.7
 pip uninstall intel-opencl-rt dpcpp-cpp-rt

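After installing the wheels above, a quick sanity check (not part of this commit; a minimal sketch assuming a working driver stack) confirms that PyTorch can see the XPU backend before running local chat:

```python
import torch

print(torch.__version__)                 # expect a "+xpu" 2.7 build
print(torch.xpu.is_available())          # True once the Intel GPU and drivers are detected
if torch.xpu.is_available():
    print(torch.xpu.get_device_name(0))  # name of the first XPU device
```
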
@@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate, get_compute_capability
+from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

@@ -78,7 +78,7 @@ def local_chat(
     if mode == 'long_context':
         assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
         torch.set_default_dtype(torch.float16)
-    elif torch.xpu.is_available() and config.architectures[0] == "DeepseekV3ForCausalLM":
+    elif xpu_fp16_model(config):
        torch.set_default_dtype(torch.float16)
     else:
         torch.set_default_dtype(config.torch_dtype)

@@ -94,11 +94,16 @@ def local_chat(
                 config._attn_implementation = "eager"
             if "Mixtral" in config.architectures[0]:
                 config._attn_implementation = "flash_attention_2"
 
+            if torch.xpu.is_available():
+                config._attn_implementation = "eager"
             model = custom_models[config.architectures[0]](config)
         else:
+            if torch.xpu.is_available():
+                attn_implementation = "eager"
+            else:
+                attn_implementation = "flash_attention_2"
             model = AutoModelForCausalLM.from_config(
-                config, trust_remote_code=True, attn_implementation="flash_attention_2"
+                config, trust_remote_code=True, attn_implementation=attn_implementation
             )
 
     if optimize_config_path is None:

@@ -13,6 +13,7 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
 from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
+from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_loader import GGUFLoader

@@ -870,3 +871,75 @@ class KLlamaAttention(BaseInjectedModule):
         attn_weights = None
 
         return attn_output, attn_weights, past_key_value
+
+
+class KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention):
+    def __init__(self,
+                 key: str,
+                 gguf_loader: GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 prefill_device: str = "xpu",
+                 generate_device: str = "xpu",
+                 chunck_size: int = 1000,
+                 **kwargs):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+        self.orig_module.__init__(orig_module.config,
+                                  orig_module.layer_idx)
+        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.
+        assert prefill_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
+        assert generate_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device"
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.Tensor],
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        bsz, q_len, _ = hidden_states.size()
+        input_dtype = hidden_states.dtype
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        if not hasattr(self, 'qkv_proj'):
+            from ipex_llm.transformers.models.common import merge_quantized_qkv
+            merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module)
+
+        qkv = self.qkv_proj(hidden_states)
+        qkv = qkv.view(bsz, q_len, -1, self.head_dim)
+        qkv = qkv.transpose(1, 2)
+        query_states, key_states, value_states = qkv.split([self.config.num_attention_heads,
+                                                            self.config.num_key_value_heads,
+                                                            self.config.num_key_value_heads], dim=1)
+        query_states = self.q_norm(query_states)
+        key_states = self.k_norm(key_states)
+
+        if position_embeddings is None:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        cos, sin = position_embeddings
+
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states.half(), value_states.half(),
+                                                             self.layer_idx, cache_kwargs)
+
+        attn_weights = None
+        from ipex_llm.transformers.models.common import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states.half(), key_states, value_states,
+            attention_mask.half(), q_len == key_states.size(2), self.scaling
+        )
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output).to(input_dtype)
+        return attn_output, attn_weights

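The new attention class above fuses the separate q/k/v projections into one `qkv_proj` and then splits the result by head counts. The following standalone sketch (plain PyTorch, illustrative head counts only, not the IPEX-LLM kernels) shows the shape bookkeeping that split performs:

```python
import torch

bsz, q_len, head_dim = 1, 8, 128
num_attention_heads, num_key_value_heads = 64, 4   # illustrative Qwen3-MoE-like values

# One fused projection output instead of three separate q/k/v projections.
qkv = torch.randn(bsz, q_len, (num_attention_heads + 2 * num_key_value_heads) * head_dim)
qkv = qkv.view(bsz, q_len, -1, head_dim).transpose(1, 2)   # -> (bsz, total_heads, q_len, head_dim)

q, k, v = qkv.split([num_attention_heads, num_key_value_heads, num_key_value_heads], dim=1)
print(q.shape, k.shape, v.shape)
# torch.Size([1, 64, 8, 128]) torch.Size([1, 4, 8, 128]) torch.Size([1, 4, 8, 128])
```
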
@@ -1421,19 +1421,28 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         return final_out
 
 class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
-    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
+    def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):
 
         orig_shape = hidden_states.shape
         sequence_length = orig_shape[1]
 
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
-        router_logits = self.gate(hidden_states, bsz_tensor)
+        if bsz_tensor is None:
+            router_logits = self.gate(hidden_states)
+        else:
+            router_logits = self.gate(hidden_states, bsz_tensor)
 
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        if self.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        if router_logits.device.type == "xpu":
+            from ipex_llm.transformers.models.common import moe_softmax_topk
+            selected_experts, routing_weights = moe_softmax_topk(
+                router_logits.half(), self.top_k, self.norm_topk_prob
+            )
+        else:
+            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
+            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+            if self.norm_topk_prob:
+                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)

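On XPU the block calls IPEX-LLM's fused `moe_softmax_topk`; elsewhere it falls back to the eager softmax-plus-top-k routing shown in the `else` branch. A self-contained numeric sketch of that fallback (made-up logits, one token and four experts):

```python
import torch
import torch.nn.functional as F

top_k, norm_topk_prob = 2, True
router_logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])   # one token, four experts

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if norm_topk_prob:
    # renormalize over the chosen experts only
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

print(selected_experts)   # tensor([[0, 2]])
print(routing_weights)    # roughly tensor([[0.7311, 0.2689]])
```
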
@@ -207,16 +207,19 @@ class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
                  generate_device: str = "xpu",
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-        self.orig_module.__init__(orig_module.hidden_size,
+        self.orig_module.__init__(orig_module.weight.shape[0],
                                   orig_module.variance_epsilon)
         self.eps = orig_module.variance_epsilon
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         from ipex_llm.transformers.models.common import rms_norm_forward
-        output = rms_norm_forward(self, x.float())
+        if x.dtype not in [torch.float32, torch.float16]:
+            output = rms_norm_forward(self, x.float())
+        else:
+            output = rms_norm_forward(self, x)
         return output.to(x.dtype)
 
     def load(self):
         BaseInjectedModule.load(self)
-        if self.weight.dtype != torch.float32:
+        if self.weight.dtype not in [torch.float32, torch.float16]:
             self.weight = self.weight.float()

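For reference, `rms_norm_forward` implements standard RMSNorm; the wrapper above only decides whether the input needs an upcast first. A plain-PyTorch equivalent of that normalization (illustrative, not the IPEX-LLM kernel):

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale by the reciprocal root-mean-square over the last dimension, then apply the learned weight.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.ones(8)
# Mirrors the wrapper's dtype handling: fp16/fp32 pass through, other dtypes are upcast to float first.
x_in = x if x.dtype in (torch.float32, torch.float16) else x.float()
print(rms_norm_reference(x_in, w).to(x.dtype).dtype)   # torch.bfloat16
```
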
@@ -306,6 +306,12 @@ class KQwen2MoeModel(BaseInjectedModule):
 
         hidden_states = inputs_embeds
 
+        # create position embeddings to be shared across the decoder layers
+        if torch.xpu.is_available() and inputs_embeds.device.type == "xpu":
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        else:
+            position_embeddings = None
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None

@@ -369,6 +375,7 @@ class KQwen2MoeModel(BaseInjectedModule):
                     output_router_logits=output_router_logits,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )
             if per_layer_prefill_flag:
                 # print(f"to cpu")

@@ -376,8 +383,10 @@ class KQwen2MoeModel(BaseInjectedModule):
                 torch.cuda.empty_cache()
             hidden_states = layer_outputs[0]
 
-            if use_cache:
+            if use_cache and len(layer_outputs) > 1:
                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+            else:
+                next_decoder_cache = None
 
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

@@ -396,11 +405,14 @@ class KQwen2MoeModel(BaseInjectedModule):
 
         next_cache = None
         if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if use_legacy_cache
-                else next_decoder_cache
-            )
+            if next_decoder_cache is not None:
+                next_cache = (
+                    next_decoder_cache.to_legacy_cache()
+                    if use_legacy_cache
+                    else next_decoder_cache
+                )
+            else:
+                next_cache = past_key_values
 
         if not return_dict:
             return tuple(

ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml (new file, 80 additions)

@@ -0,0 +1,80 @@
- match:
    name: "rotary_emb$"
  replace:
    class: ktransformers.operators.RoPE.KQwen3MoeRotaryEmbedding
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
      generate_op: "KLinearIPEXLLM"
      prefill_op: "KLinearIPEXLLM"
- match:
    name: "^model\\.layers\\.(?!.*mlp\\.gate).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
      generate_op: "KLinearIPEXLLM"
      prefill_op: "KLinearIPEXLLM"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
  replace:
    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2  # MLP module with custom forward function
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExpertsV2  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "xpu"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "xpu"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KQwen3MoeAttentionIPEXLLM
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KQwen2MoeModel"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm
  replace:
    class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP
  replace:
    class: ktransformers.operators.mlp.KQwen2MoeMLP
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"

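The new rule file follows ktransformers' usual match/replace injection format: attention, MoE blocks, linears, and norms go to XPU via IPEX-LLM operators, while the experts' generate path stays on CPU. A small sketch (assuming PyYAML and a repository checkout) that prints what each rule injects; at runtime this file is what `optimize_config_path` points to in `local_chat`:

```python
import yaml  # PyYAML

path = "ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml"
with open(path) as f:
    rules = yaml.safe_load(f)

for rule in rules:
    match = rule.get("match", {})
    target = match.get("name") or match.get("class")
    print(f"{target}  ->  {rule['replace']['class']}")
```
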
@@ -459,9 +459,10 @@ class GGUFLoader(ModelLoader):
         if "cuda" in device.lower():
             values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
         else:
-            values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values).to(device)
+            np_values = np.copy(GGML_DEQUANTIZE[ggml_name](data))
+            values = torch.from_numpy(np_values).to(device)
+            del np_values
 
         if ggml_name == "BF16":
             values = values.view(torch.bfloat16)

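The extra copy matters because `torch.from_numpy` shares memory with the array it wraps; copying first gives the tensor its own buffer, so the dequantized (possibly read-only or memory-map-backed) array can be released right after. A minimal illustration of the difference:

```python
import numpy as np
import torch

a = np.ones(4, dtype=np.float32)

t_shared = torch.from_numpy(a)            # shares a's buffer: writes are visible on both sides
a[0] = 42.0
assert t_shared[0].item() == 42.0

t_owned = torch.from_numpy(np.copy(a))    # np.copy decouples the tensor from a
a[1] = 7.0
assert t_owned[1].item() == 1.0
```
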
@@ -144,6 +144,18 @@ def sync_all_device(all_device_list):
 
 torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}
 
+def xpu_fp16_model(config):
+    # Check whether this model should run on XPU with the FP16 dtype
+    if not torch.xpu.is_available():
+        return False
+    if config.architectures[0] == "DeepseekV3ForCausalLM":
+        return True
+    if config.architectures[0] == "Qwen3MoeForCausalLM" and config.hidden_size == 4096:
+        # Qwen3-30B seems to have precision issues with FP16,
+        # so we only use FP16 for Qwen3-235B for now
+        return True
+    return False
+
 def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
     #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):

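A minimal sketch of how `local_chat` uses this helper to choose the default dtype; `DummyConfig` is a hypothetical stand-in for the Hugging Face config object, and on a machine without an XPU the helper simply returns False:

```python
import torch
from ktransformers.util.utils import xpu_fp16_model

class DummyConfig:                        # hypothetical stand-in for the HF model config
    architectures = ["Qwen3MoeForCausalLM"]
    hidden_size = 4096                    # the Qwen3-235B case singled out above
    torch_dtype = torch.bfloat16

config = DummyConfig()
if xpu_fp16_model(config):                # True only when torch.xpu.is_available()
    torch.set_default_dtype(torch.float16)
else:
    torch.set_default_dtype(config.torch_dtype)
```
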
@@ -277,8 +289,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
 
     stream = TextStreamer(tokenizer)
     if torch.xpu.is_available():
-        from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
-        past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+        from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache
+        if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
+            past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+        else:
+            past_key_values = DynamicNormalCache.from_legacy_cache(None)
     elif mode != 'long_context':
         past_key_values = StaticCache(
             config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype