diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py
index 06c569b..22d580b 100644
--- a/ktransformers/operators/layernorm.py
+++ b/ktransformers/operators/layernorm.py
@@ -163,3 +163,34 @@ class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
+    def __init__(self,
+                 key: str,
+                 gguf_loader: GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 prefill_device: str = "cuda",
+                 generate_device: str = "cuda",
+                 **kwargs):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+        self.orig_module.__init__(orig_module.hidden_size,
+                                  orig_module.variance_epsilon)
+
+    def forward(
+        self,
+        x,
+        batch_size_tensor: Optional[torch.Tensor] = None,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        # batch_size_tensor is not used in this pure-torch path; kept for interface compatibility
+        if residual is not None:
+            x = x + residual
+            residual = x
+        input_dtype = x.dtype
+        x = x.to(torch.float32)
+        variance = x.pow(2).mean(-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        if residual is not None:
+            return self.weight * x.to(input_dtype), residual
+        return self.weight * x.to(input_dtype)
\ No newline at end of file
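
For reference, below is a minimal standalone sketch (not part of the patch) of the computation DeepseekV3RMSNormTorch.forward performs: RMSNorm in fp32 with an optional fused residual add, where the pre-norm sum is returned as the new residual. The helper name rms_norm_ref and the tensor shapes are illustrative assumptions, not part of the ktransformers API.

import torch
from typing import Optional, Tuple, Union

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float,
                 residual: Optional[torch.Tensor] = None
                 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Mirrors the forward() added in the patch, outside the injection framework.
    if residual is not None:
        x = x + residual      # fuse the residual add before normalization
        residual = x          # the pre-norm sum becomes the next layer's residual
    input_dtype = x.dtype
    x = x.to(torch.float32)   # normalize in fp32 for numerical stability
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    out = weight * x.to(input_dtype)
    return (out, residual) if residual is not None else out

# Illustrative usage: 4 tokens, hidden size 8.
hidden = torch.randn(4, 8, dtype=torch.float16)
res = torch.randn(4, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
out, new_res = rms_norm_ref(hidden, w, 1e-6, residual=res)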