add XPU support for qwen3moe local chat

This commit is contained in:
rnwang04 2025-05-21 18:33:41 +08:00
parent 25893366b6
commit adc0906967
9 changed files with 223 additions and 25 deletions

View file

@ -207,16 +207,19 @@ class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
generate_device: str = "xpu",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.hidden_size,
self.orig_module.__init__(orig_module.weight.shape[0],
orig_module.variance_epsilon)
self.eps = orig_module.variance_epsilon
def forward(self, x: torch.Tensor) -> torch.Tensor:
from ipex_llm.transformers.models.common import rms_norm_forward
output = rms_norm_forward(self, x.float())
if x.dtype not in [torch.float32, torch.float16]:
output = rms_norm_forward(self, x.float())
else:
output = rms_norm_forward(self, x)
return output.to(x.dtype)
def load(self):
BaseInjectedModule.load(self)
if self.weight.dtype != torch.float32:
if self.weight.dtype not in [torch.float32, torch.float16]:
self.weight = self.weight.float()