kvcache-ai-ktransformers/archive/ktransformers/operators/gate.py
Jiaqi Liao 57d14d22bc
Refactor: restructure repository to focus on kt-kernel and KT-SFT modules (#1581)
* refactor: move legacy code to archive/ directory

  - Moved ktransformers, csrc, third_party, merge_tensors to archive/
  - Moved build scripts and configurations to archive/
  - Kept kt-kernel, KT-SFT, doc, and README files in root
  - Preserved complete git history for all moved files

* refactor: restructure repository to focus on kt-kernel and KT-SFT modules

* fix README

* fix README

* fix README

* fix README

* docs: add performance benchmarks to kt-kernel section

Add comprehensive performance data for kt-kernel to match KT-SFT's presentation:
- AMX kernel optimization: 21.3 TFLOPS (3.9× faster than PyTorch)
- Prefill phase: up to 20× speedup vs baseline
- Decode phase: up to 4× speedup
- NUMA optimization: up to 63% throughput improvement
- Multi-GPU (8×L20): 227.85 tokens/s total throughput with DeepSeek-R1 FP8

Source: https://lmsys.org/blog/2025-10-22-KTransformers/

This provides users with concrete performance metrics for both core modules,
making it easier to understand the capabilities of each component.

* refactor: improve kt-kernel performance data with specific hardware and models

Replace generic performance descriptions with concrete benchmarks:
- Specify exact hardware: 8×L20 GPU + Xeon Gold 6454S, Single/Dual-socket Xeon + AMX
- Include specific models: DeepSeek-R1-0528 (FP8), DeepSeek-V3 (671B)
- Show detailed metrics: total throughput, output throughput, concurrency details
- Match KT-SFT presentation style for consistency

This provides users with actionable performance data they can use to evaluate
hardware requirements and expected performance for their use cases.

* fix README

* docs: clean up performance table and improve formatting

* add pic for README

* refactor: simplify .gitmodules and backup legacy submodules

- Remove 7 legacy submodules from root .gitmodules (archive/third_party/*)
- Keep only 2 active submodules for kt-kernel (llama.cpp, pybind11)
- Backup complete .gitmodules to archive/.gitmodules
- Add documentation in archive/README.md for researchers who need legacy submodules

This reduces initial clone size by ~500MB and avoids downloading unused dependencies.

* refactor: move doc/ back to root directory

Keep documentation in root for easier access and maintenance.

* refactor: consolidate all images to doc/assets/

- Move kt-kernel/assets/heterogeneous_computing.png to doc/assets/
- Remove KT-SFT/assets/ (images already in doc/assets/)
- Update KT-SFT/README.md image references to ../doc/assets/
- Eliminates ~7.9MB image duplication
- Centralizes all documentation assets in one location

* fix pic path for README
2025-11-10 17:42:26 +08:00

226 lines
9.4 KiB
Python

from typing import Optional
from torch import nn
import torch
import torch.nn.functional as F
import os
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.operators.linear import KTransformersLinear
from ktransformers.util.custom_loader import GGUFLoader, ModelLoader, SafeTensorLoader, translate_name_to_gguf
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
# class Base(BaseInjectedModule, ABC):
class KMoEGateBase(ABC):
    """Abstract base for injected MoE gate operators.

    Stores the module key, the weight loader, the model config, the target
    device and the original module, and provides ``load_weights`` to fetch the
    gate tensors from whichever loader backs the model.
    """

    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 device: str = "cuda",
                 **kwargs):
        super().__init__()
        self.key = key
        self.gguf_loader = gguf_loader
        self.config = config
        self.device = device
        self.orig_module = orig_module

    @abstractmethod
    def forward(self, input_tensor, expert_ids, weights):
        pass

    @abstractmethod
    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu", warmup: bool = False):
        pass

    @abstractmethod
    def unload(self):
        pass

    def load_weights(self, override_key: str | list[str] | None = None, device: str = "cpu"):
        """Load the gate weight and expert-score correction bias.

        Args:
            override_key: key (or list of keys) to load instead of ``self.key``.
                A single string is treated as one key.
            device: device forwarded to the GGUF / safetensor loaders.

        Returns:
            dict with at least ``"weight"`` and ``"e_score_correction_bias"``.
            The NPU path also adds ``"weight_type"`` and
            ``"e_score_correction_bias_type"``, and the ``SafeTensorLoader`` path
            returns whatever ``load_gate`` produces. When several keys are given,
            only the result for the last key is returned.

        Raises:
            ValueError: if no loader path can provide ``<key>.weight``.
        """
        if override_key is None:
            keys = [self.key]
        elif isinstance(override_key, str):
            # Iterating a bare string would walk it character by character.
            keys = [override_key]
        else:
            keys = override_key

        res = {}
        for key in keys:
            if self.gguf_loader.safetensor_loader is not None:
                # NPU path: tensors are looked up under their GGUF-style names,
                # truncated to the first two dotted components.
                translate_key = translate_name_to_gguf(key)
                translate_key = ".".join(translate_key.split(".")[:2])
                loader = self.gguf_loader.safetensor_loader
                weight = loader.load_tensor(translate_key + ".ffn_gate_inp.weight")
                e_score_correction_bias = loader.load_tensor(translate_key + ".exp_probs_b.bias")
                res = {
                    "weight": weight,
                    "e_score_correction_bias": e_score_correction_bias,
                    "weight_type": weight.dtype,
                    "e_score_correction_bias_type": e_score_correction_bias.dtype,
                }
            elif isinstance(self.gguf_loader, SafeTensorLoader):
                res = self.gguf_loader.load_gate(key, device=device)
            elif self.gguf_loader.has_tensor(key + ".weight"):
                tensors = self.load_multi(key, [".weight", ".e_score_correction_bias"], device=device)
                res = {
                    "weight": tensors[".weight"],
                    "e_score_correction_bias": tensors[".e_score_correction_bias"],
                }
            else:
                raise ValueError(f"Experts {key} not found in gguf_loader")
        return res

    def load_multi(self, key: str, keys: list[str], device: str = "cpu"):
        """Load ``key + suffix`` for every suffix in ``keys`` from the GGUF loader.

        Returns a dict mapping each suffix to its loaded tensor.
        """
        return {k: self.gguf_loader.load_gguf_tensor(key + k, device=device) for k in keys}
class KMoEGate(BaseInjectedModule, KMoEGateBase):
    """Gate operator whose forward simply delegates to the original module.

    ``load`` installs ``weight`` and ``e_score_correction_bias`` as parameters
    on ``orig_module`` (on the requested device), and ``unload`` clears those
    same attributes again.
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def forward(self, hidden_states) -> torch.Tensor:
        return self.orig_module.forward(hidden_states)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        """Install gate tensors on ``orig_module``.

        Args:
            w: dict with ``"weight"`` and ``"e_score_correction_bias"``; when
                None, it is fetched with ``self.load_weights``.
            device: target device; defaults to ``self.device``.

        Raises:
            ValueError: if ``w`` is not a dict.
        """
        if device is None:
            device = self.device
        if w is None:
            w = self.load_weights(device=device)
        if not isinstance(w, dict):
            raise ValueError("Invalid weight type")
        self.orig_module.weight = nn.Parameter(w["weight"].to(device))
        self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"].to(device))

    def unload(self):
        """Release the tensors that ``load`` installed on ``orig_module``."""
        # load() stores the tensors on orig_module, so clear them there as well.
        for name in ("weight", "e_score_correction_bias"):
            if getattr(self.orig_module, name, None) is not None:
                setattr(self.orig_module, name, None)
class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
    """Gate operator that computes gate logits itself and routes with ``grouped_topk``.

    On Windows (``os.name == 'nt'``) forward falls back to ``orig_module.forward``.
    Otherwise logits come from a ``KTransformersLinear`` when ``use_quant`` is
    set, or from a float32 ``F.linear`` against ``self.weight`` when it is not.
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "cuda",
        generate_op: str | None = "KLinearMarlin",
        prefill_device: str = "cuda",
        prefill_op: str | None = "KLinearMarlin",
        use_quant: bool = False,
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device
        self.generate_op = generate_op
        self.prefill_op = prefill_op
        self.is_windows = os.name == 'nt'
        self.use_quant = use_quant
        if not self.is_windows and use_quant:
            # NOTE(review): gating_dim / n_routed_experts are not set in this class;
            # presumably resolved from orig_module by BaseInjectedModule — confirm.
            self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
            self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                   gguf_loader, config, self.gate_linear,  # orig_module
                                                   generate_device, generate_op, prefill_device, prefill_op)
        else:
            self.gate_linear = None

    def forward(self, hidden_states) -> torch.Tensor:
        """Compute gate logits for a 3-D ``hidden_states`` and route them."""
        if self.is_windows:
            return self.orig_module.forward(hidden_states)
        _bsz, _seq_len, h = hidden_states.shape
        # Flatten batch and sequence so every token is one row.
        hidden_states = hidden_states.view(-1, h)
        if self.use_quant:
            # Project the flattened hidden states through the quantized gate.
            logits = self.gate_linear.forward(hidden_states)
        else:
            logits = F.linear(
                hidden_states.type(torch.float32), self.weight.type(torch.float32), None
            )
        # NOTE(review): `grouped_topk` is not imported in this file's visible
        # imports — confirm where it is expected to come from.
        return grouped_topk(hidden_states, logits,
                            self.top_k, self.norm_topk_prob,
                            self.n_group, self.topk_group)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        """Install gate tensors on ``orig_module`` and feed the weight to ``gate_linear``.

        Args:
            w: dict with ``"weight"`` and ``"e_score_correction_bias"``; when
                None, it is fetched with ``self.load_weights``.
            device: target device; defaults to ``self.device``.

        Raises:
            ValueError: if ``w`` is not a dict.
        """
        if device is None:
            device = self.device
        if w is None:
            w = self.load_weights(device=device)
        if not isinstance(w, dict):
            raise ValueError("Invalid weight type")
        self.orig_module.weight = nn.Parameter(w["weight"].to(device))
        self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"].to(device))
        if not self.is_windows and self.use_quant:
            self.gate_linear.load(self.orig_module.weight)

    def unload(self):
        """Release the tensors that ``load`` installed on ``orig_module``."""
        # load() stores the tensors on orig_module, so clear them there as well.
        for name in ("weight", "e_score_correction_bias"):
            if getattr(self.orig_module, name, None) is not None:
                setattr(self.orig_module, name, None)
class KMoEGateIPEXLLM(KMoEGate):
    """KMoEGate variant whose forward routes with ipex_llm's ``moe_group_topk``.

    Devices default to ``"xpu"``. Weight loading and unloading are inherited
    from ``KMoEGate``.
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "xpu",
        prefill_device: str = "xpu",
        **kwargs,
    ):
        # KMoEGate.__init__ already runs BaseInjectedModule.__init__ and sets
        # generate_device / prefill_device. Pass both by keyword so that
        # prefill_device is not silently replaced by KMoEGate's "cuda" default.
        KMoEGate.__init__(self, key, gguf_loader, config, orig_module,
                          generate_device=generate_device,
                          prefill_device=prefill_device,
                          **kwargs)

    def forward(self, hidden_states) -> torch.Tensor:
        """Return ``(topk_idx, topk_weight)`` for the flattened tokens.

        Logits are computed in float32 and passed through a sigmoid.
        ``topk_weight`` is cast back to the input dtype.
        """
        x = hidden_states.view(-1, hidden_states.size(-1))
        logits = torch.nn.functional.linear(
            x.type(torch.float32), self.orig_module.weight.type(torch.float32), None
        )
        scores = logits.sigmoid()
        # Imported lazily so ipex_llm is only required when this operator runs.
        from ipex_llm.transformers.models.common import moe_group_topk
        topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias,
                                               self.n_group, self.topk_group, self.top_k,
                                               self.norm_topk_prob, self.routed_scaling_factor)
        return topk_idx, topk_weight.to(x.dtype)