[feat](kt-kernel): CPU-GPU experts sched (#1796)
Some checks failed
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled

This commit is contained in:
Jianwei Dong 2026-01-16 17:01:15 +08:00 committed by GitHub
parent 6277da4c2b
commit 027832c590
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 687 additions and 62 deletions

View file

@ -503,16 +503,21 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size) {
return GeneralMOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size);
}))
.def(py::init(
[](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int num_gpu_experts) {
GeneralMOEConfig cfg(expert_num, routed_expert_num, hidden_size, intermediate_size);
cfg.num_gpu_experts = num_gpu_experts;
return cfg;
}))
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size,
uintptr_t gpu_experts_mask_ptr) {
GeneralMOEConfig cfg(expert_num, routed_expert_num, hidden_size, intermediate_size);
cfg.gpu_experts_mask = reinterpret_cast<uint8_t*>(gpu_experts_mask_ptr);
cfg.compute_num_gpu_experts();
return cfg;
}))
.def_readwrite("layer_idx", &GeneralMOEConfig::layer_idx)
.def_readwrite("pool", &GeneralMOEConfig::pool)
.def_readwrite("num_gpu_experts", &GeneralMOEConfig::num_gpu_experts)
.def_readonly("num_gpu_experts", &GeneralMOEConfig::num_gpu_experts)
.def_property(
"gpu_experts_mask",
[](const GeneralMOEConfig& self) { return reinterpret_cast<uintptr_t>(self.gpu_experts_mask); },
[](GeneralMOEConfig& self, uintptr_t val) { self.gpu_experts_mask = reinterpret_cast<uint8_t*>(val); })
.DEF_PTR_PROPERTY(GeneralMOEConfig, physical_to_logical_map)
.DEF_PTR_PROPERTY(GeneralMOEConfig, gate_proj)