Merge pull request #1298 from kvcache-ai/fix-workspace-buffer
Some checks are pending
Book-CI / test (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run

update norm cpu kernel
This commit is contained in:
wang jiahao 2025-05-14 17:50:40 +08:00 committed by GitHub
commit 333351c7c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -163,3 +163,34 @@ class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.hidden_size,
orig_module.variance_epsilon)
def forward(
self,
x,
batch_size_tensor: torch.Tensor = None,
residual: Optional[torch.Tensor] = None,
)-> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
x = x + residual
residual = x
# range batch_size_tensor for x
input_dtype = x.dtype
x = x.to(torch.float32)
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
if residual is not None:
return self.weight * x.to(input_dtype), residual
return self.weight * x.to(input_dtype)