'''
Date: 2024-11-13 15:05:52
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:19
'''
"""
|
|
Copyright 2023-2024 SGLang Team
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
"""Fused operators for normalization layers."""
|
|
|
|
import logging
|
|
from typing import Optional, Tuple, Union
|
|
from transformers import PretrainedConfig
|
|
import torch
|
|
import torch.nn as nn
|
|
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
|
|
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
|
|
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
|
|
from ktransformers.models.modeling_qwen3_next import Qwen3NextRMSNorm
|
|
from ktransformers.models.modeling_smallthinker import SmallthinkerRMSNorm
|
|
from ktransformers.models.modeling_glm4_moe import Glm4MoeRMSNorm
|
|
from ktransformers.operators.base_operator import BaseInjectedModule
|
|
from ktransformers.util.custom_loader import GGUFLoader
|
|
if not torch.xpu.is_available():
|
|
from flashinfer.norm import (
|
|
fused_add_rmsnorm,
|
|
rmsnorm,
|
|
)
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
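
# Each class below wraps a model-specific HuggingFace RMSNorm module via the
# ktransformers injection framework (BaseInjectedModule). The forward methods
# share one dispatch pattern, sketched here for reference (illustrative only;
# instances are normally constructed by the injection framework, not by hand):
#
#   out = norm(x)                                   # no batch_size_tensor:
#                                                   #   native PyTorch path
#   out = norm(x, batch_size_tensor)                # fused rmsnorm kernel
#   x, residual = norm(x, batch_size_tensor, res)   # fused add + rmsnorm,
#                                                   #   updates x/res in place
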
class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            # In-place fused kernel: residual <- x + residual, then
            # x <- rmsnorm(residual, weight).
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        return out

    def forward_native(self, hidden_states):
        # Reference implementation: y = w * x / sqrt(mean(x^2) + eps),
        # computed in float32 and cast back to the input dtype.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

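# Quick equivalence check between the fused and native paths (a hedged sketch;
# it assumes a CUDA build where the flashinfer kernels above are available, an
# already-injected `norm` instance, and illustrative shapes/dtypes — in
# particular, batch_size_tensor covering every row of x):
#
#   x = torch.randn(4, 7168, device="cuda", dtype=torch.bfloat16)
#   bs = torch.tensor([4], dtype=torch.int32, device="cuda")
#   torch.testing.assert_close(norm(x.clone(), bs), norm.forward_native(x),
#                              rtol=1e-2, atol=1e-2)
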
class KQwen2MoeRMSNorm(Qwen2MoeRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(config.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            # In-place fused kernel: residual <- x + residual, then
            # x <- rmsnorm(residual, weight).
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        return out

    def forward_native(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        bsz, hidden_size = x.shape
        # Flatten to (tokens, hidden_size) for the fused kernels, then
        # restore the original shape on the way out.
        x = x.view(-1, self.orig_module.hidden_size)
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            # In-place fused kernel: residual <- x + residual, then
            # x <- rmsnorm(residual, weight).
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        out = out.view(bsz, hidden_size)
        return out

    def forward_native(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

class KQwen3NextRMSNorm(Qwen3NextRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x, num_tokens_tensors, residual=None):
        # num_tokens_tensors is accepted for interface parity with the fused
        # variants but is not used by this native implementation.
        if residual is not None:
            x = x + residual
            residual = x
        x = x.view(-1, self.orig_module.hidden_size)
        output = self._norm(x.float())
        # Llama does x.to(float16) * w, whereas Qwen3Next computes
        # (x * w).to(float16); see https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        if residual is None:
            return output.type_as(x)
        return output.type_as(x), residual

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

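# Note on the (1.0 + weight) above: Qwen3Next stores its norm weight
# zero-centered (Gemma-style), so the effective scale is 1 + w and a freshly
# initialized w == 0 is the identity scale, e.g.:
#
#   w = torch.zeros(8)
#   assert torch.equal(x_normed * (1.0 + w), x_normed)
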
class KSmallthinkerRMSNorm(SmallthinkerRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        bsz, hidden_size = x.shape
        # Flatten to (tokens, hidden_size) for the fused kernels, then
        # restore the original shape on the way out.
        x = x.view(-1, self.orig_module.hidden_size)
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            # In-place fused kernel: residual <- x + residual, then
            # x <- rmsnorm(residual, weight).
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        out = out.view(bsz, hidden_size)
        return out

    def forward_native(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

class KGlm4MoeRMSNorm(Glm4MoeRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        bsz, hidden_size = x.shape
        # Flatten to (tokens, hidden_size) for the fused kernels, then
        # restore the original shape on the way out.
        x = x.view(-1, self.orig_module.hidden_size)
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            # In-place fused kernel: residual <- x + residual, then
            # x <- rmsnorm(residual, weight).
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        out = out.view(bsz, hidden_size)
        return out

    def forward_native(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            x = x + residual
            residual = x
        # batch_size_tensor is accepted for interface parity with the fused
        # variants above but is not used by this pure-PyTorch implementation.
        input_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        if residual is not None:
            return self.weight * x.to(input_dtype), residual
        return self.weight * x.to(input_dtype)

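# DeepseekV3RMSNormTorch never touches the flashinfer kernels, so it
# presumably serves as a dependency-free fallback that runs on any device.
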
class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "xpu",
                 generate_device: str = "xpu",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.weight.shape[0],
                                  orig_module.variance_epsilon)
        self.eps = orig_module.variance_epsilon

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Imported lazily so CUDA-only deployments do not need ipex_llm.
        from ipex_llm.transformers.models.common import rms_norm_forward

        # Upcast dtypes the ipex_llm kernel does not handle and cast the
        # result back to the input dtype afterwards.
        if x.dtype not in [torch.float32, torch.float16]:
            output = rms_norm_forward(self, x.float())
        else:
            output = rms_norm_forward(self, x)
        return output.to(x.dtype)

    def load(self):
        BaseInjectedModule.load(self)
        if self.weight.dtype not in [torch.float32, torch.float16]:
            self.weight = self.weight.float()
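
# Typical usage: these operators are selected by a ktransformers optimize-rule
# YAML entry rather than instantiated directly. A hedged sketch (the module
# path and the regex below are illustrative assumptions, not taken from this
# file):
#
#   - match:
#       name: "^model\\.layers\\..*\\.input_layernorm$"
#     replace:
#       class: ktransformers.operators.layernorm.RMSNorm
#       kwargs:
#         generate_device: "cuda"
#         prefill_device: "cuda"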