mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
Enable support for Intel XPU devices, add support for DeepSeek V2/V3 first
This commit is contained in:
parent
333351c7c8
commit
142fb7ce6c
22 changed files with 673 additions and 81 deletions
|
@ -30,10 +30,11 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
|
|||
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.util.custom_loader import GGUFLoader
|
||||
from flashinfer.norm import (
|
||||
fused_add_rmsnorm,
|
||||
rmsnorm,
|
||||
)
|
||||
if not torch.xpu.is_available():
|
||||
from flashinfer.norm import (
|
||||
fused_add_rmsnorm,
|
||||
rmsnorm,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -193,4 +194,29 @@ class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
|
|||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
if residual is not None:
|
||||
return self.weight * x.to(input_dtype), residual
|
||||
return self.weight * x.to(input_dtype)
|
||||
return self.weight * x.to(input_dtype)
|
||||
|
||||
|
||||
class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
    """Injected RMSNorm that delegates the computation to ipex_llm.

    The forward pass calls ``rms_norm_forward`` from
    ``ipex_llm.transformers.models.common`` on a float32 copy of the input
    and casts the result back to the input's dtype. Device defaults are
    ``"xpu"`` for both the prefill and generate phases.
    """

    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "xpu",
                 generate_device: str = "xpu",
                 **kwargs):
        """Wrap ``orig_module`` and record the epsilon used by the norm.

        Args:
            key: Identifier passed through to ``BaseInjectedModule``.
            gguf_loader: Loader passed through to ``BaseInjectedModule``.
            config: Model configuration passed through to ``BaseInjectedModule``.
            orig_module: The module being replaced; must expose
                ``hidden_size`` and ``variance_epsilon``.
            prefill_device: Device string for the prefill phase.
            generate_device: Device string for the generate phase.
            **kwargs: Extra options forwarded to ``BaseInjectedModule``.
        """
        # Forward generate_device too: previously only prefill_device was
        # passed on, so the value configured by the caller (or the "xpu"
        # default declared above) never reached the base class.
        # NOTE(review): passed by keyword so it binds correctly whatever the
        # base class's positional order is -- confirm against BaseInjectedModule.
        BaseInjectedModule.__init__(
            self,
            key,
            gguf_loader,
            config,
            orig_module,
            prefill_device,
            generate_device=generate_device,
            **kwargs,
        )
        # Re-run the wrapped module's own constructor with its current
        # hidden size and epsilon.
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)
        # Keep epsilon available under the name ``eps`` as well.
        # NOTE(review): presumably the attribute rms_norm_forward reads -- confirm.
        self.eps = orig_module.variance_epsilon

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply RMS normalization to ``x`` and return it in ``x``'s dtype."""
        # Imported lazily so the module can be imported without ipex_llm installed.
        from ipex_llm.transformers.models.common import rms_norm_forward

        # Compute on a float32 copy, then cast back to the caller's dtype.
        output = rms_norm_forward(self, x.float())
        return output.to(x.dtype)

    def load(self):
        """Load via ``BaseInjectedModule.load`` and ensure the weight is float32."""
        BaseInjectedModule.load(self)
        if self.weight.dtype != torch.float32:
            self.weight = self.weight.float()
|
Loading…
Add table
Add a link
Reference in a new issue