mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-06 04:30:03 +00:00
Merge pull request #1298 from kvcache-ai/fix-workspace-buffer
update norm cpu kernel
This commit is contained in:
commit
333351c7c8
1 changed files with 31 additions and 0 deletions
|
@ -163,3 +163,34 @@ class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
|
|||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.hidden_size,
|
||||
orig_module.variance_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
batch_size_tensor: torch.Tensor = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
)-> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if residual is not None:
|
||||
x = x + residual
|
||||
residual = x
|
||||
# range batch_size_tensor for x
|
||||
input_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
if residual is not None:
|
||||
return self.weight * x.to(input_dtype), residual
|
||||
return self.weight * x.to(input_dtype)
|
Loading…
Add table
Reference in a new issue