add XPU support for qwen3moe local chat

rnwang04 2025-05-21 18:33:41 +08:00
parent 25893366b6
commit adc0906967
9 changed files with 223 additions and 25 deletions


@@ -144,6 +144,18 @@ def sync_all_device(all_device_list):
 torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}
+def xpu_fp16_model(config):
+    # Check whether this model should be run on XPU with FP16 dtype
+    if not torch.xpu.is_available():
+        return False
+    if config.architectures[0] == "DeepseekV3ForCausalLM":
+        return True
+    if config.architectures[0] == "Qwen3MoeForCausalLM" and config.hidden_size == 4096:
+        # Qwen3-30B seems to have precision issues with FP16,
+        # so we only use FP16 for Qwen3-235B for now
+        return True
+    return False
 def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
     #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
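Note: the helper added in this hunk only decides whether FP16 is safe on XPU; a caller still has to turn that into a dtype. A minimal sketch of such a caller, assuming the xpu_fp16_model helper above is in scope (pick_xpu_dtype and the BF16 fallback are illustrative assumptions, not part of this commit):

import torch

def pick_xpu_dtype(config):
    # Use FP16 only for the models xpu_fp16_model() approves
    # (DeepseekV3 and the hidden_size == 4096 Qwen3 MoE variant);
    # fall back to BF16 here (the loader's actual default may differ).
    return torch.float16 if xpu_fp16_model(config) else torch.bfloat16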
@@ -277,8 +289,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     stream = TextStreamer(tokenizer)
     if torch.xpu.is_available():
-        from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
-        past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+        from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache
+        if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
+            past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+        else:
+            past_key_values = DynamicNormalCache.from_legacy_cache(None)
     elif mode != 'long_context':
         past_key_values = StaticCache(
             config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
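The cache-selection change above reads as: DeepSeek V2/V3 keep the unbalanced FP8 KV cache that the XPU path used before this commit, while other architectures (such as the newly supported Qwen3 MoE) get a regular dynamic cache. A standalone sketch of that rule, assuming ipex_llm's DynamicUnbalancedFp8Cache and DynamicNormalCache are available (make_xpu_cache is a hypothetical helper name, not part of this commit):

def make_xpu_cache(model):
    from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache
    if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
        # DeepSeek models keep the unbalanced FP8 KV cache used before this change.
        return DynamicUnbalancedFp8Cache.from_legacy_cache(None)
    # Other XPU models (e.g. Qwen3MoeForCausalLM) use a normal dynamic cache.
    return DynamicNormalCache.from_legacy_cache(None)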