Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2026-04-28 03:39:48 +00:00
[Feature] Add avx-based kimi-k2 support (#1656)
Some checks are pending
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
* Support Kimi-K2-Thinking original weight; fix AMX kernel bug
* Update k2 avx kernel
* feat: add CPUInfer write buffer task
* [feat]: add kimi k2 cpu write buffer support
  - Implement write_weights_to_buffer function in k2-moe.hpp for extracting GPU expert weights
  - Fix down (w2) weight column-wise slicing for different TP configurations
  - Support three TP scenarios: cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp
  - Add comprehensive test cases for weight-extraction validation
  - Ensure compatibility with the Kimi model's MoE architecture
* [fix]: correct write_weight_scale_to_buffer expert offset calculation
  Fixes a bug in write_weight_scale_to_buffer_task where expert offsets into GPU buffers were computed from per_expert_gpu sizes; they now use the full gpu_tp sizes, ensuring the correct memory layout for multi-expert scenarios. Also adds benchmark scripts for the k2 MoE and write-buffer operations, and cleans up debug output in the test files.
* [feat]: add write buffer wrapper
* [fix]: fix comment

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude <noreply@anthropic.com>
parent c2b8c60c4e
commit fcf8882075
12 changed files with 2649 additions and 34 deletions
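Note on the offset fix described in the commit message: the sketch below illustrates the corrected stride arithmetic in isolation. It is a minimal, hedged reconstruction; every name and value (num_experts, w2_bytes, gpu_tp, shard_bytes) is invented for illustration, and the real logic lives in write_weight_scale_to_buffer in k2-moe.hpp, which this page does not show.

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative sizes only; none of these values come from the repository.
  const size_t num_experts = 4;     // experts packed into one GPU-side buffer
  const size_t w2_bytes    = 4096;  // full (unsliced) down-projection weight per expert
  const size_t gpu_tp      = 2;     // GPU tensor-parallel degree

  // w2 is sliced column-wise across GPU ranks, so each rank holds
  // w2_bytes / gpu_tp bytes per expert. Consecutive experts in the
  // rank-local buffer are separated by that full per-rank shard size;
  // striding by a smaller per-expert sub-slice misplaces every expert
  // after the first, which is the bug the commit message describes.
  const size_t shard_bytes = w2_bytes / gpu_tp;
  for (size_t e = 0; e < num_experts; ++e)
    std::printf("expert %zu shard begins at byte %zu\n", e, e * shard_bytes);
  return 0;
}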
@@ -36,6 +36,7 @@ static const bool _is_plain_ = false;
 #if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
 #include "operators/amx/awq-moe.hpp"
+#include "operators/amx/k2-moe.hpp"
 #include "operators/amx/la/amx_kernels.hpp"
 #include "operators/amx/moe.hpp"
 #endif
@@ -43,6 +44,7 @@ static const bool _is_plain_ = false;
 #include <cstdint>
 #include <memory>
+#include <type_traits>

 #include "operators/kvcache/kvcache.h"
 #include "operators/llamafile/linear.h"
@@ -225,7 +227,9 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
   using MoeClass = TP_MOE<MoeTP>;
   using MoeBindings = MOEBindings<MoeTP>;

-  py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name)
+  auto moe_cls = py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name);
+
+  moe_cls
       .def(py::init<GeneralMOEConfig>())
       .def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface)
       .def("load_weights_task",
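The hunk above replaces the single chained py::class_ expression with a named moe_cls handle. The point of the refactor shows up in the next hunk: holding the handle lets conditionally compiled code attach extra methods afterward. A minimal sketch of the pattern, assuming only stock pybind11 (the Moe struct and USE_EXTRA_BINDING macro are invented for illustration):

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Moe {
  int forward(int x) const { return x * 2; }
};

PYBIND11_MODULE(example, m) {
  // Keep the class_ object in a variable instead of one chained expression...
  auto cls = py::class_<Moe>(m, "Moe");
  cls.def(py::init<>())
     .def("forward", &Moe::forward);

#if defined(USE_EXTRA_BINDING)
  // ...so later blocks can register bindings for only some configurations.
  cls.def("forward_twice",
          [](const Moe& self, int x) { return self.forward(self.forward(x)); });
#endif
}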
@@ -244,6 +248,53 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
       .def("warm_up", &MoeClass::warm_up)
       .def("load_weights", &MoeClass::load_weights)
       .def("forward", &MoeClass::forward_binding);
+
+#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
+  if constexpr (std::is_same_v<MoeTP, AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>) {
+    struct WriteWeightScaleToBufferBindings {
+      struct Args {
+        CPUInfer* cpuinfer;
+        MoeClass* moe;
+        int gpu_tp_count;
+        int gpu_experts_num;
+        std::vector<uintptr_t> w13_weight_ptrs;
+        std::vector<uintptr_t> w13_scale_ptrs;
+        std::vector<uintptr_t> w2_weight_ptrs;
+        std::vector<uintptr_t> w2_scale_ptrs;
+      };
+
+      static void inner(void* args) {
+        Args* args_ = (Args*)args;
+        args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe,
+                                 args_->gpu_tp_count, args_->gpu_experts_num,
+                                 args_->w13_weight_ptrs, args_->w13_scale_ptrs,
+                                 args_->w2_weight_ptrs, args_->w2_scale_ptrs);
+      }
+
+      static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe,
+                                                              int gpu_tp_count, int gpu_experts_num,
+                                                              py::list w13_weight_ptrs, py::list w13_scale_ptrs,
+                                                              py::list w2_weight_ptrs, py::list w2_scale_ptrs) {
+        // Convert Python lists to std::vector<uintptr_t>
+        std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;
+
+        for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
+        for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
+        for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
+        for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
+
+        Args* args = new Args{nullptr, moe.get(), gpu_tp_count, gpu_experts_num,
+                              w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
+        return std::make_pair((intptr_t)&inner, (intptr_t)args);
+      }
+    };
+
+    moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
+                py::arg("gpu_tp_count"), py::arg("gpu_experts_num"),
+                py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
+                py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
+  }
+#endif
 }

 PYBIND11_MODULE(kt_kernel_ext, m) {
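cpuinfer_interface above hands the task back to Python as a raw (function pointer, argument pointer) pair of intptr_t values, and Args.cpuinfer starts out as nullptr, presumably filled in by CPUInfer before inner runs. Below is a stripped-down sketch of that submit contract; FakeArgs, fake_task, and make_task are invented stand-ins, and the real argument lifetime management is not shown in this diff.

#include <cstdint>
#include <cstdio>
#include <utility>

// Invented stand-in for the binding's Args struct.
struct FakeArgs {
  int gpu_tp_count;
};

// Invented stand-in for the binding's inner() task body.
static void fake_task(void* raw) {
  auto* args = static_cast<FakeArgs*>(raw);
  std::printf("task ran with gpu_tp_count=%d\n", args->gpu_tp_count);
  delete args;  // the sketch frees the heap-allocated args itself; the real
                // code's ownership rules are outside this hunk
}

// Mirrors cpuinfer_interface: package the work as two opaque integers.
std::pair<intptr_t, intptr_t> make_task(int gpu_tp_count) {
  auto* args = new FakeArgs{gpu_tp_count};
  return {reinterpret_cast<intptr_t>(&fake_task),
          reinterpret_cast<intptr_t>(args)};
}

int main() {
  auto [fn, raw] = make_task(4);
  // What a CPUInfer-style executor does with a submitted pair:
  reinterpret_cast<void (*)(void*)>(fn)(reinterpret_cast<void*>(raw));
  return 0;
}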
@@ -513,6 +564,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
   bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_MOE");
   bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
   bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
+  bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
 #endif
 #if defined(USE_MOE_KERNEL)
   bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");