mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 11:49:51 +00:00
[Feature] Add avx-based kimi-k2 support (#1656)
Some checks are pending
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
Some checks are pending
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
* support Kimi-K2-Thinking original weight fix amx kernel bug * update k2 avx kernel. * feat: add CPUInfer write buffer task * [feat]: add kimi k2 cpu write buffer support - Implement write_weights_to_buffer function in k2-moe.hpp for extracting GPU expert weights - Fix down (w2) weight column-wise slicing for different TP configurations - Support three TP scenarios: cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp - Add comprehensive test cases for weight extraction validation - Ensure compatibility with Kimi model's MoE architecture * [fix]: correct write_weight_scale_to_buffer expert offset calculation Fixed the bug in write_weight_scale_to_buffer_task where expert offsets in GPU buffers were incorrectly calculated. Changed from using per_expert_gpu sizes to using full gpu_tp sizes, ensuring correct memory layout for multi-expert scenarios. Also added benchmark scripts for k2 moe and write buffer operations, and cleaned up debug output in test files. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * [feat]: add write buffer wrapper * [fix] fix comment --------- Co-authored-by: ouqingliang <1692110604@qq.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
c2b8c60c4e
commit
fcf8882075
12 changed files with 2649 additions and 34 deletions
|
|
@ -17,7 +17,7 @@ from typing import List, Optional
|
|||
from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
|
||||
|
||||
# Import backend implementations
|
||||
from .utils.amx import AMXMoEWrapper
|
||||
from .utils.amx import AMXMoEWrapper, RAWAMXMoEWrapper
|
||||
from .utils.llamafile import LlamafileMoEWrapper
|
||||
from .utils.moe_kernel import GeneralMoEWrapper
|
||||
|
||||
|
|
@ -77,7 +77,7 @@ class KTMoEWrapper:
|
|||
chunked_prefill_size: Maximum prefill chunk size
|
||||
cpu_save: Whether to save weights to CPU memory
|
||||
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
|
||||
method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
|
||||
method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
|
||||
|
||||
Returns:
|
||||
An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
|
||||
|
|
@ -85,6 +85,8 @@ class KTMoEWrapper:
|
|||
# Select backend based on method
|
||||
if method in ["AMXINT4", "AMXINT8"]:
|
||||
backend_cls = AMXMoEWrapper
|
||||
elif method == "RAWINT4":
|
||||
backend_cls = RAWAMXMoEWrapper
|
||||
elif method == "LLAMAFILE":
|
||||
backend_cls = LlamafileMoEWrapper
|
||||
elif method in ["MOE_INT4", "MOE_INT8"]:
|
||||
|
|
|
|||
|
|
@ -4,13 +4,15 @@
|
|||
Utilities for kt_kernel package.
|
||||
"""
|
||||
|
||||
from .amx import AMXMoEWrapper
|
||||
from .amx import AMXMoEWrapper, RAWAMXMoEWrapper
|
||||
from .llamafile import LlamafileMoEWrapper
|
||||
from .loader import SafeTensorLoader, GGUFLoader
|
||||
from .loader import SafeTensorLoader, GGUFLoader, CompressedSafeTensorLoader
|
||||
|
||||
__all__ = [
|
||||
"AMXMoEWrapper",
|
||||
"RAWAMXMoEWrapper",
|
||||
"LlamafileMoEWrapper",
|
||||
"SafeTensorLoader",
|
||||
"CompressedSafeTensorLoader",
|
||||
"GGUFLoader",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,16 +4,16 @@ import ctypes
|
|||
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import SafeTensorLoader
|
||||
from .loader import SafeTensorLoader, CompressedSafeTensorLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE
|
||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE
|
||||
|
||||
_HAS_AMX_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_AMX_SUPPORT = False
|
||||
AMXInt4_MOE, AMXInt8_MOE = None, None
|
||||
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -301,3 +301,152 @@ class AMXMoEWrapper(BaseMoEWrapper):
|
|||
del self.gate_scales
|
||||
del self.up_scales
|
||||
del self.down_scales
|
||||
|
||||
|
||||
class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
"""Wrapper for RAWINT4 experts stored in compressed SafeTensor format."""
|
||||
|
||||
_compressed_loader_instance = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_idx: int,
|
||||
num_experts: int,
|
||||
num_experts_per_tok: int,
|
||||
hidden_size: int,
|
||||
moe_intermediate_size: int,
|
||||
num_gpu_experts: int,
|
||||
cpuinfer_threads: int,
|
||||
threadpool_count: int,
|
||||
weight_path: str,
|
||||
chunked_prefill_size: int,
|
||||
cpu_save: bool = False,
|
||||
max_deferred_experts_per_token: Optional[int] = None,
|
||||
method: str = "RAWINT4",
|
||||
):
|
||||
if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None:
|
||||
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
|
||||
|
||||
super().__init__(
|
||||
layer_idx=layer_idx,
|
||||
num_experts=num_experts,
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
num_gpu_experts=num_gpu_experts,
|
||||
cpuinfer_threads=cpuinfer_threads,
|
||||
threadpool_count=threadpool_count,
|
||||
weight_path=weight_path,
|
||||
chunked_prefill_size=chunked_prefill_size,
|
||||
cpu_save=cpu_save,
|
||||
max_deferred_experts_per_token=max_deferred_experts_per_token,
|
||||
method=method,
|
||||
)
|
||||
|
||||
if RAWAMXMoEWrapper._compressed_loader_instance is None:
|
||||
RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path)
|
||||
self.loader = RAWAMXMoEWrapper._compressed_loader_instance
|
||||
|
||||
self.gate_weights = None
|
||||
self.up_weights = None
|
||||
self.down_weights = None
|
||||
self.gate_scales = None
|
||||
self.up_scales = None
|
||||
self.down_scales = None
|
||||
|
||||
def load_weights_from_tensors(
|
||||
self,
|
||||
gate_proj: torch.Tensor,
|
||||
up_proj: torch.Tensor,
|
||||
down_proj: torch.Tensor,
|
||||
physical_to_logical_map_cpu: torch.Tensor,
|
||||
):
|
||||
raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.")
|
||||
|
||||
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
|
||||
base_key = f"model.layers.{self.layer_idx}"
|
||||
weights = self.loader.load_experts(base_key)
|
||||
|
||||
self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous()
|
||||
self.up_weights = torch.stack(weights["up"], dim=0).contiguous()
|
||||
self.down_weights = torch.stack(weights["down"], dim=0).contiguous()
|
||||
|
||||
self.gate_scales = torch.stack(weights["gate_scale"], dim=0).to(torch.bfloat16).contiguous()
|
||||
self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous()
|
||||
self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous()
|
||||
|
||||
moe_config = MOEConfig(
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
self.hidden_size,
|
||||
self.moe_intermediate_size,
|
||||
self.num_gpu_experts,
|
||||
)
|
||||
moe_config.layer_idx = self.layer_idx
|
||||
moe_config.pool = self.cpu_infer.backend_
|
||||
moe_config.max_len = self.chunked_prefill_size
|
||||
|
||||
moe_config.quant_config.bits = 4
|
||||
moe_config.quant_config.group_size = 32
|
||||
moe_config.quant_config.zero_point = False
|
||||
|
||||
moe_config.gate_proj = self.gate_weights.data_ptr()
|
||||
moe_config.up_proj = self.up_weights.data_ptr()
|
||||
moe_config.down_proj = self.down_weights.data_ptr()
|
||||
moe_config.gate_scale = self.gate_scales.data_ptr()
|
||||
moe_config.up_scale = self.up_scales.data_ptr()
|
||||
moe_config.down_scale = self.down_scales.data_ptr()
|
||||
|
||||
self.moe = AMXInt4_KGroup_MOE(moe_config)
|
||||
|
||||
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
|
||||
self.cpu_infer.sync()
|
||||
|
||||
del self.gate_weights
|
||||
del self.up_weights
|
||||
del self.down_weights
|
||||
del self.gate_scales
|
||||
del self.up_scales
|
||||
del self.down_scales
|
||||
|
||||
def submit_write_weight_scale_to_buffer(
|
||||
self,
|
||||
gpu_tp_count: int,
|
||||
gpu_experts_num: int,
|
||||
w13_weight_ptrs,
|
||||
w13_scale_ptrs,
|
||||
w2_weight_ptrs,
|
||||
w2_scale_ptrs,
|
||||
):
|
||||
"""
|
||||
Submit the write_weight_scale_to_buffer task for RAWINT4 KGroup AMX implementation.
|
||||
|
||||
This method submits the C++-exposed task `write_weight_scale_to_buffer_task` to the
|
||||
shared CPUInfer queue. The pointer lists should be plain integer lists (e.g. from
|
||||
tensor.data_ptr()).
|
||||
"""
|
||||
if self.moe is None:
|
||||
raise RuntimeError("MoE instance not initialized; cannot submit write_weight_scale_to_buffer task.")
|
||||
|
||||
if not hasattr(self.moe, "write_weight_scale_to_buffer_task"):
|
||||
raise NotImplementedError(
|
||||
"write_weight_scale_to_buffer_task is not available for this backend implementation."
|
||||
)
|
||||
|
||||
self.cpu_infer.submit(
|
||||
self.moe.write_weight_scale_to_buffer_task(
|
||||
gpu_tp_count,
|
||||
gpu_experts_num,
|
||||
w13_weight_ptrs,
|
||||
w13_scale_ptrs,
|
||||
w2_weight_ptrs,
|
||||
w2_scale_ptrs,
|
||||
)
|
||||
)
|
||||
|
||||
def sync_write_weight_scale_to_buffer(self):
|
||||
"""
|
||||
Block until previously submitted write_weight_scale_to_buffer tasks finish.
|
||||
"""
|
||||
# The CPUInfer.sync() call blocks until pending tasks complete.
|
||||
self.cpu_infer.sync()
|
||||
|
|
|
|||
|
|
@ -237,6 +237,56 @@ class SafeTensorLoader:
|
|||
return name in self.tensor_file_map
|
||||
|
||||
|
||||
class CompressedSafeTensorLoader(SafeTensorLoader):
|
||||
"""Loader for compressed SafeTensor layouts (RAWINT4 weights)."""
|
||||
|
||||
def load_experts(self, base_key: str, device: str = "cpu"):
|
||||
"""Load raw expert weights stored in compressed safetensor format."""
|
||||
|
||||
experts_prefix = f"{base_key}.mlp.experts"
|
||||
|
||||
expert_idx = 0
|
||||
while self.has_tensor(f"{experts_prefix}.{expert_idx}.up_proj.weight_packed"):
|
||||
expert_idx += 1
|
||||
|
||||
if expert_idx == 0:
|
||||
raise ValueError(f"No experts found for key {experts_prefix}")
|
||||
|
||||
def load_projection(proj_name: str):
|
||||
weight_entries = []
|
||||
scale_entries = []
|
||||
|
||||
for exp_id in range(expert_idx):
|
||||
weight_key = f"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_packed"
|
||||
scale_key = f"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_scale"
|
||||
|
||||
if not self.has_tensor(weight_key):
|
||||
raise KeyError(f"Missing tensor: {weight_key}")
|
||||
if not self.has_tensor(scale_key):
|
||||
raise KeyError(f"Missing tensor: {scale_key}")
|
||||
|
||||
weight_tensor = self.load_tensor(weight_key, device).contiguous()
|
||||
scale_tensor = self.load_tensor(scale_key, device).contiguous()
|
||||
|
||||
weight_entries.append(weight_tensor)
|
||||
scale_entries.append(scale_tensor)
|
||||
|
||||
return weight_entries, scale_entries
|
||||
|
||||
gate_weights, gate_scales = load_projection("gate")
|
||||
up_weights, up_scales = load_projection("up")
|
||||
down_weights, down_scales = load_projection("down")
|
||||
|
||||
return {
|
||||
"gate": gate_weights,
|
||||
"up": up_weights,
|
||||
"down": down_weights,
|
||||
"gate_scale": gate_scales,
|
||||
"up_scale": up_scales,
|
||||
"down_scale": down_scales,
|
||||
}
|
||||
|
||||
|
||||
class GGUFLoader:
|
||||
"""
|
||||
GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue