Kt minimax (#1742)

[feat]: fp8 kernel and kt-cli support
ErvinXie 2025-12-24 15:39:44 +08:00 committed by GitHub
parent e7d277d163
commit d8046e1bb4
65 changed files with 12111 additions and 2502 deletions
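The FP8 path added in this commit consumes DeepSeek-style block-quantized expert weights: float8 data plus a float32 weight_scale_inv tensor holding one scale per 128x128 block (the wrapper below sets quant_config.bits = 8, quant_config.group_size = 128, and asserts float32 scales). The following minimal PyTorch sketch illustrates that storage layout by round-tripping a weight through block-wise FP8; it is an illustration of the format only, not the AMX kernel's implementation, and the helper names are invented for this sketch.

import torch

BLOCK = 128  # group_size used by the FP8 quant_config in this commit

def quantize_fp8_blockwise(w: torch.Tensor):
    """Quantize a 2-D weight to float8_e4m3fn with one float32 scale_inv per 128x128 block."""
    out_f, in_f = w.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Pad so both dimensions are multiples of BLOCK, then view as [R, BLOCK, C, BLOCK] blocks.
    pad_r, pad_c = (-out_f) % BLOCK, (-in_f) % BLOCK
    wp = torch.nn.functional.pad(w.float(), (0, pad_c, 0, pad_r))
    blocks = wp.view(wp.shape[0] // BLOCK, BLOCK, wp.shape[1] // BLOCK, BLOCK)
    amax = blocks.abs().amax(dim=(1, 3)).clamp(min=1e-12)        # per-block absolute max, shape [R, C]
    scale = fp8_max / amax
    q = (blocks * scale[:, None, :, None]).clamp(-fp8_max, fp8_max)
    q = q.reshape_as(wp).to(torch.float8_e4m3fn)[:out_f, :in_f].contiguous()
    return q, (1.0 / scale).to(torch.float32)                    # FP8 weight, weight_scale_inv

def dequantize_fp8_blockwise(q: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Recover a bf16 weight: each 128x128 block is multiplied by its scale_inv entry."""
    out_f, in_f = q.shape
    pad_r, pad_c = (-out_f) % BLOCK, (-in_f) % BLOCK
    qp = torch.nn.functional.pad(q.float(), (0, pad_c, 0, pad_r))
    blocks = qp.view(qp.shape[0] // BLOCK, BLOCK, qp.shape[1] // BLOCK, BLOCK)
    w = blocks * scale_inv[:, None, :, None]
    return w.reshape_as(qp)[:out_f, :in_f].to(torch.bfloat16)

w = torch.randn(256, 384, dtype=torch.bfloat16)
q, s_inv = quantize_fp8_blockwise(w)
print(q.dtype, s_inv.dtype, s_inv.shape)   # torch.float8_e4m3fn torch.float32 torch.Size([2, 3])
print((dequantize_fp8_blockwise(q, s_inv).float() - w.float()).abs().max())  # small reconstruction error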

@@ -4,13 +4,13 @@
Utilities for kt_kernel package.
"""
from .amx import AMXMoEWrapper, RAWAMXMoEWrapper
from .amx import AMXMoEWrapper, NativeMoEWrapper
from .llamafile import LlamafileMoEWrapper
from .loader import SafeTensorLoader, GGUFLoader, CompressedSafeTensorLoader
__all__ = [
"AMXMoEWrapper",
"RAWAMXMoEWrapper",
"NativeMoEWrapper",
"LlamafileMoEWrapper",
"SafeTensorLoader",
"CompressedSafeTensorLoader",

@@ -4,16 +4,16 @@ import ctypes
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import SafeTensorLoader, CompressedSafeTensorLoader
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader
from kt_kernel_ext.moe import MOEConfig
try:
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE
_HAS_AMX_SUPPORT = True
except (ImportError, AttributeError):
_HAS_AMX_SUPPORT = False
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE = None, None, None, None
from typing import Optional
@@ -303,10 +303,10 @@ class AMXMoEWrapper(BaseMoEWrapper):
del self.down_scales
class RAWAMXMoEWrapper(BaseMoEWrapper):
"""Wrapper for RAWINT4 experts stored in compressed SafeTensor format."""
class NativeMoEWrapper(BaseMoEWrapper):
"""Wrapper for RAWINT4/FP8 experts stored in compressed SafeTensor format."""
_compressed_loader_instance = None
_native_loader_instance = None
def __init__(
self,
@@ -324,8 +324,12 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
max_deferred_experts_per_token: Optional[int] = None,
method: str = "RAWINT4",
):
if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None:
if not _HAS_AMX_SUPPORT:
raise RuntimeError("AMX backend is not available.")
if method == "RAWINT4" and AMXInt4_KGroup_MOE is None:
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
if method == "FP8" and AMXFP8_MOE is None:
raise RuntimeError("AMX backend with FP8 support is not available.")
super().__init__(
layer_idx=layer_idx,
@@ -343,9 +347,14 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
method=method,
)
if RAWAMXMoEWrapper._compressed_loader_instance is None:
RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path)
self.loader = RAWAMXMoEWrapper._compressed_loader_instance
if NativeMoEWrapper._native_loader_instance is None:
if method == "RAWINT4":
NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
elif method == "FP8":
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
else:
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
self.loader = NativeMoEWrapper._native_loader_instance
self.gate_weights = None
self.up_weights = None
@@ -378,9 +387,17 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
self.down_weights = weights["down"]
# Convert scales to bf16 individually
self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
# self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
# self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
# self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
self.gate_scales = weights["gate_scale"]
self.up_scales = weights["up_scale"]
self.down_scales = weights["down_scale"]
if self.method == "RAWINT4":
assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
elif self.method == "FP8":
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
t2 = time.time()
# Build pointer lists: [numa_id][expert_id] -> pointer
@@ -404,18 +421,6 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
moe_config.pool = self.cpu_infer.backend_
moe_config.max_len = self.chunked_prefill_size
# Infer group_size from scale shape (column-major layout)
# For gate/up projection: in_features = hidden_size
# So: group_size = hidden_size / scale.shape[1]
scale_shape = self.gate_scales[0].shape
group_size = self.hidden_size // scale_shape[1]
print(f"[RAWAMXMoEWrapper Layer {self.layer_idx}] Inferred group_size: {group_size}")
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = group_size
moe_config.quant_config.zero_point = False
# Use gate_projs instead of gate_proj for per-expert pointers
moe_config.gate_projs = gate_ptrs
moe_config.up_projs = up_ptrs
@@ -424,7 +429,21 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
moe_config.up_scales = up_scale_ptrs
moe_config.down_scales = down_scale_ptrs
self.moe = AMXInt4_KGroup_MOE(moe_config)
# Infer group_size from scale shape (column-major layout)
# For gate/up projection: in_features = hidden_size
# So: group_size = hidden_size / scale.shape[1]
if self.method == "RAWINT4":
group_size = self.hidden_size // self.gate_scales[0].shape[1]
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = group_size
moe_config.quant_config.zero_point = False
self.moe = AMXInt4_KGroup_MOE(moe_config)
elif self.method == "FP8":
moe_config.quant_config.bits = 8
moe_config.quant_config.group_size = 128
moe_config.quant_config.zero_point = False
self.moe = AMXFP8_MOE(moe_config)
t4 = time.time()
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
@@ -440,7 +459,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
t6 = time.time()
print(
f"[RAWAMXMoEWrapper Layer {self.layer_idx}] "
f"[NativeMoEWrapper Layer {self.layer_idx}] "
f"load_experts: {(t1-t0)*1000:.1f}ms, "
f"prepare_tensors: {(t2-t1)*1000:.1f}ms, "
f"build_ptrs: {(t3-t2)*1000:.1f}ms, "
@@ -453,7 +472,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
def submit_write_weight_scale_to_buffer(
self,
gpu_tp_count: int,
gpu_experts_num: int,
expert_id: int,
w13_weight_ptrs,
w13_scale_ptrs,
w2_weight_ptrs,
@@ -477,7 +496,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
self.cpu_infer.submit(
self.moe.write_weight_scale_to_buffer_task(
gpu_tp_count,
gpu_experts_num,
expert_id,
w13_weight_ptrs,
w13_scale_ptrs,
w2_weight_ptrs,

@@ -219,4 +219,4 @@ class LlamafileMoEWrapper(BaseMoEWrapper):
self.cpu_infer.sync()
# Drop original weights after loading
self.weights_to_keep = None
self.weights_to_keep = None
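The loader changes in the next file add FP8SafeTensorLoader, which auto-detects the expert tensor naming convention before loading. As a small illustration of the two key templates it distinguishes (templates taken from the docstring below; the layer and expert indices are invented examples), a caller's tensors would be named like this:

# Illustrative tensor key names; layer/expert indices are examples, not from this diff.
deepseek_gate = "model.layers.3.mlp.experts.0.gate_proj.weight"          # "deepseek" format
mixtral_gate = "model.layers.3.block_sparse_moe.experts.0.w1.weight"     # "mixtral" / MiniMax format
# Block-wise scales live next to each weight under the same key plus a suffix:
print(deepseek_gate.replace(".weight", ".weight_scale_inv"))
# model.layers.3.mlp.experts.0.gate_proj.weight_scale_inv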

@@ -237,6 +237,117 @@ class SafeTensorLoader:
return name in self.tensor_file_map
class FP8SafeTensorLoader(SafeTensorLoader):
"""Loader for FP8 expert weights with auto-detection of naming formats.
Supported formats:
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
The format is auto-detected during initialization.
"""
# Known MoE naming formats: (experts_path_template, gate_name, up_name, down_name)
MOE_FORMATS = {
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
}
def __init__(self, file_path: str):
super().__init__(file_path)
self._detected_format = None
self._detect_format()
def _detect_format(self):
"""Auto-detect the MoE naming format by checking tensor keys."""
# Sample some tensor names to detect format
sample_keys = list(self.tensor_file_map.keys())[:1000]
for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
# Check if any key matches this format pattern
# Look for pattern like: model.layers.0.{experts_path}.0.{gate_name}.weight
for key in sample_keys:
if ".experts." in key and f".{gate}.weight" in key:
# Verify the path template matches
if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
return
elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
return
# Default to deepseek if no format detected
self._detected_format = "deepseek"
print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
def _get_experts_prefix(self, base_key: str) -> str:
"""Get the experts prefix based on detected format."""
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
return path_tpl.format(base=base_key)
def _get_proj_names(self):
"""Get projection names (gate, up, down) based on detected format."""
_, gate, up, down = self.MOE_FORMATS[self._detected_format]
return gate, up, down
def load_tensor(self, key: str, device: str = "cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
if device == "cpu":
return tensor
return tensor.to(device)
def load_experts(self, base_key: str, device: str = "cpu"):
"""Load FP8 expert weights and their block-wise scale_inv tensors."""
experts_prefix = self._get_experts_prefix(base_key)
gate_name, up_name, down_name = self._get_proj_names()
expert_count = 0
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
if expert_count == 0:
raise ValueError(f"No experts found for key {experts_prefix}")
gate_weights = [None] * expert_count
up_weights = [None] * expert_count
down_weights = [None] * expert_count
gate_scales = [None] * expert_count
up_scales = [None] * expert_count
down_scales = [None] * expert_count
for exp_id in range(expert_count):
gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight_scale_inv"
up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.weight_scale_inv"
down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.weight_scale_inv"
gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
gate_scales[exp_id] = self.load_tensor(gate_s_key, device).contiguous()
up_scales[exp_id] = self.load_tensor(up_s_key, device).contiguous()
down_scales[exp_id] = self.load_tensor(down_s_key, device).contiguous()
return {
"gate": gate_weights,
"up": up_weights,
"down": down_weights,
"gate_scale": gate_scales,
"up_scale": up_scales,
"down_scale": down_scales,
}
class CompressedSafeTensorLoader(SafeTensorLoader):
"""Loader for compressed SafeTensor layouts (RAWINT4 weights)."""