[Feature] Add avx-based kimi-k2 support (#1656)
Some checks are pending
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run

* support Kimi-K2-Thinking original weight
fix amx kernel bug

* update k2 avx kernel.

* feat: add CPUInfer write buffer task

* [feat]: add kimi k2 cpu write buffer support

- Implement write_weights_to_buffer function in k2-moe.hpp for extracting GPU expert weights
- Fix down (w2) weight column-wise slicing for different TP configurations
- Support three TP scenarios: cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp
- Add comprehensive test cases for weight extraction validation
- Ensure compatibility with Kimi model's MoE architecture

* [fix]: correct write_weight_scale_to_buffer expert offset calculation

Fixed the bug in write_weight_scale_to_buffer_task where expert offsets in GPU buffers were incorrectly calculated. Changed from using per_expert_gpu sizes to using full gpu_tp sizes, ensuring correct memory layout for multi-expert scenarios.

Also added benchmark scripts for k2 moe and write buffer operations, and cleaned up debug output in test files.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* [feat]: add write buffer wrapper

* [fix] fix comment

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Jiaqi Liao 2025-12-02 16:01:07 +08:00 committed by GitHub
parent c2b8c60c4e
commit fcf8882075
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 2649 additions and 34 deletions

View file

@ -4,16 +4,16 @@ import ctypes
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import SafeTensorLoader
from .loader import SafeTensorLoader, CompressedSafeTensorLoader
from kt_kernel_ext.moe import MOEConfig
try:
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE
_HAS_AMX_SUPPORT = True
except (ImportError, AttributeError):
_HAS_AMX_SUPPORT = False
AMXInt4_MOE, AMXInt8_MOE = None, None
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None
from typing import Optional
@ -301,3 +301,152 @@ class AMXMoEWrapper(BaseMoEWrapper):
del self.gate_scales
del self.up_scales
del self.down_scales
class RAWAMXMoEWrapper(BaseMoEWrapper):
"""Wrapper for RAWINT4 experts stored in compressed SafeTensor format."""
_compressed_loader_instance = None
def __init__(
self,
layer_idx: int,
num_experts: int,
num_experts_per_tok: int,
hidden_size: int,
moe_intermediate_size: int,
num_gpu_experts: int,
cpuinfer_threads: int,
threadpool_count: int,
weight_path: str,
chunked_prefill_size: int,
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
method: str = "RAWINT4",
):
if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None:
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
super().__init__(
layer_idx=layer_idx,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
num_gpu_experts=num_gpu_experts,
cpuinfer_threads=cpuinfer_threads,
threadpool_count=threadpool_count,
weight_path=weight_path,
chunked_prefill_size=chunked_prefill_size,
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
)
if RAWAMXMoEWrapper._compressed_loader_instance is None:
RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path)
self.loader = RAWAMXMoEWrapper._compressed_loader_instance
self.gate_weights = None
self.up_weights = None
self.down_weights = None
self.gate_scales = None
self.up_scales = None
self.down_scales = None
def load_weights_from_tensors(
self,
gate_proj: torch.Tensor,
up_proj: torch.Tensor,
down_proj: torch.Tensor,
physical_to_logical_map_cpu: torch.Tensor,
):
raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.")
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
base_key = f"model.layers.{self.layer_idx}"
weights = self.loader.load_experts(base_key)
self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous()
self.up_weights = torch.stack(weights["up"], dim=0).contiguous()
self.down_weights = torch.stack(weights["down"], dim=0).contiguous()
self.gate_scales = torch.stack(weights["gate_scale"], dim=0).to(torch.bfloat16).contiguous()
self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous()
self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous()
moe_config = MOEConfig(
self.num_experts,
self.num_experts_per_tok,
self.hidden_size,
self.moe_intermediate_size,
self.num_gpu_experts,
)
moe_config.layer_idx = self.layer_idx
moe_config.pool = self.cpu_infer.backend_
moe_config.max_len = self.chunked_prefill_size
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = 32
moe_config.quant_config.zero_point = False
moe_config.gate_proj = self.gate_weights.data_ptr()
moe_config.up_proj = self.up_weights.data_ptr()
moe_config.down_proj = self.down_weights.data_ptr()
moe_config.gate_scale = self.gate_scales.data_ptr()
moe_config.up_scale = self.up_scales.data_ptr()
moe_config.down_scale = self.down_scales.data_ptr()
self.moe = AMXInt4_KGroup_MOE(moe_config)
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
self.cpu_infer.sync()
del self.gate_weights
del self.up_weights
del self.down_weights
del self.gate_scales
del self.up_scales
del self.down_scales
def submit_write_weight_scale_to_buffer(
self,
gpu_tp_count: int,
gpu_experts_num: int,
w13_weight_ptrs,
w13_scale_ptrs,
w2_weight_ptrs,
w2_scale_ptrs,
):
"""
Submit the write_weight_scale_to_buffer task for RAWINT4 KGroup AMX implementation.
This method submits the C++-exposed task `write_weight_scale_to_buffer_task` to the
shared CPUInfer queue. The pointer lists should be plain integer lists (e.g. from
tensor.data_ptr()).
"""
if self.moe is None:
raise RuntimeError("MoE instance not initialized; cannot submit write_weight_scale_to_buffer task.")
if not hasattr(self.moe, "write_weight_scale_to_buffer_task"):
raise NotImplementedError(
"write_weight_scale_to_buffer_task is not available for this backend implementation."
)
self.cpu_infer.submit(
self.moe.write_weight_scale_to_buffer_task(
gpu_tp_count,
gpu_experts_num,
w13_weight_ptrs,
w13_scale_ptrs,
w2_weight_ptrs,
w2_scale_ptrs,
)
)
def sync_write_weight_scale_to_buffer(self):
"""
Block until previously submitted write_weight_scale_to_buffer tasks finish.
"""
# The CPUInfer.sync() call blocks until pending tasks complete.
self.cpu_infer.sync()