Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 13:55:27 +00:00
add XPU support for qwen3moe local chat
parent 25893366b6
commit adc0906967
9 changed files with 223 additions and 25 deletions
@@ -207,16 +207,19 @@ class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
                  generate_device: str = "xpu",
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-        self.orig_module.__init__(orig_module.hidden_size,
+        self.orig_module.__init__(orig_module.weight.shape[0],
                                   orig_module.variance_epsilon)
         self.eps = orig_module.variance_epsilon
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         from ipex_llm.transformers.models.common import rms_norm_forward
-        output = rms_norm_forward(self, x.float())
+        if x.dtype not in [torch.float32, torch.float16]:
+            output = rms_norm_forward(self, x.float())
+        else:
+            output = rms_norm_forward(self, x)
         return output.to(x.dtype)
 
     def load(self):
         BaseInjectedModule.load(self)
-        if self.weight.dtype != torch.float32:
+        if self.weight.dtype not in [torch.float32, torch.float16]:
             self.weight = self.weight.float()
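The substance of this hunk is a dtype-handling change around ipex_llm's rms_norm_forward: the norm is now initialized from orig_module.weight.shape[0] rather than orig_module.hidden_size, forward only up-casts inputs whose dtype is neither float32 nor float16, and load likewise leaves float16 weights in place instead of forcing them to float32. Below is a minimal plain-PyTorch sketch of the resulting behaviour, not the ipex_llm implementation: rms_norm_reference is a hypothetical stand-in for the fused rms_norm_forward kernel, and the assumption that the XPU kernel handles float32/float16 inputs natively is inferred from the branch this diff adds.

import torch

def rms_norm_reference(module, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for ipex_llm's fused rms_norm_forward:
    # classic RMSNorm computed in the input's dtype,
    # x * rsqrt(mean(x^2) + eps), scaled by the learned weight.
    variance = x.pow(2).mean(-1, keepdim=True)
    return module.weight.to(x.dtype) * (x * torch.rsqrt(variance + module.eps))

class RMSNormSketch(torch.nn.Module):
    # Sketch of the dtype handling added in the diff, assuming a module
    # with `weight` and `eps` attributes as in the patched class.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only up-cast dtypes the kernel is assumed not to accept (e.g. bfloat16);
        # float32/float16 inputs pass through unchanged, and the result is cast
        # back to the caller's dtype, mirroring the patched forward.
        if x.dtype not in [torch.float32, torch.float16]:
            output = rms_norm_reference(self, x.float())
        else:
            output = rms_norm_reference(self, x)
        return output.to(x.dtype)

Dropping the unconditional .float() presumably saves one cast (and the associated copy) per layer when a float16 model runs on XPU, which is consistent with the commit's stated goal of enabling Qwen3-MoE local chat on that backend.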