[Feature] Add avx-based kimi-k2 support (#1656)
Some checks are pending
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run

* support Kimi-K2-Thinking original weight
fix amx kernel bug

* update k2 avx kernel.

* feat: add CPUInfer write buffer task

* [feat]: add kimi k2 cpu write buffer support

- Implement write_weights_to_buffer function in k2-moe.hpp for extracting GPU expert weights
- Fix down (w2) weight column-wise slicing for different TP configurations
- Support three TP scenarios: cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp
- Add comprehensive test cases for weight extraction validation
- Ensure compatibility with Kimi model's MoE architecture

* [fix]: correct write_weight_scale_to_buffer expert offset calculation

Fixed the bug in write_weight_scale_to_buffer_task where expert offsets in GPU buffers were incorrectly calculated. Changed from using per_expert_gpu sizes to using full gpu_tp sizes, ensuring correct memory layout for multi-expert scenarios.

Also added benchmark scripts for k2 moe and write buffer operations, and cleaned up debug output in test files.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* [feat]: add write buffer wrapper

* [fix] fix comment

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Jiaqi Liao 2025-12-02 16:01:07 +08:00 committed by GitHub
parent c2b8c60c4e
commit fcf8882075
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 2649 additions and 34 deletions

View file

@ -36,6 +36,7 @@ static const bool _is_plain_ = false;
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
#include "operators/amx/awq-moe.hpp"
#include "operators/amx/k2-moe.hpp"
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
#endif
@ -43,6 +44,7 @@ static const bool _is_plain_ = false;
#include <cstdint>
#include <memory>
#include <type_traits>
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
@ -225,7 +227,9 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
using MoeClass = TP_MOE<MoeTP>;
using MoeBindings = MOEBindings<MoeTP>;
py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name)
auto moe_cls = py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name);
moe_cls
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
@ -244,6 +248,53 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
.def("warm_up", &MoeClass::warm_up)
.def("load_weights", &MoeClass::load_weights)
.def("forward", &MoeClass::forward_binding);
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
if constexpr (std::is_same_v<MoeTP, AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>) {
struct WriteWeightScaleToBufferBindings {
struct Args {
CPUInfer* cpuinfer;
MoeClass* moe;
int gpu_tp_count;
int gpu_experts_num;
std::vector<uintptr_t> w13_weight_ptrs;
std::vector<uintptr_t> w13_scale_ptrs;
std::vector<uintptr_t> w2_weight_ptrs;
std::vector<uintptr_t> w2_scale_ptrs;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe,
args_->gpu_tp_count, args_->gpu_experts_num,
args_->w13_weight_ptrs, args_->w13_scale_ptrs,
args_->w2_weight_ptrs, args_->w2_scale_ptrs);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe,
int gpu_tp_count, int gpu_experts_num,
py::list w13_weight_ptrs, py::list w13_scale_ptrs,
py::list w2_weight_ptrs, py::list w2_scale_ptrs) {
// Convert Python lists to std::vector<uintptr_t>
std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;
for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, gpu_experts_num,
w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
py::arg("gpu_tp_count"), py::arg("gpu_experts_num"),
py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}
#endif
}
PYBIND11_MODULE(kt_kernel_ext, m) {
@ -513,6 +564,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_MOE");
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
#endif
#if defined(USE_MOE_KERNEL)
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");