[feat](kt-kernel): CPU-GPU expert scheduling (#1796)
Some checks failed
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled

This commit is contained in:
Jianwei Dong 2026-01-16 17:01:15 +08:00 committed by GitHub
parent 6277da4c2b
commit 027832c590
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 687 additions and 62 deletions

View file

@ -1,6 +1,7 @@
import os
import torch
import ctypes
from typing import Optional
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
@ -40,7 +41,7 @@ class GeneralMoEWrapper(BaseMoEWrapper):
num_experts_per_tok: int,
hidden_size: int,
moe_intermediate_size: int,
num_gpu_experts: int,
gpu_experts_mask: Optional[torch.Tensor],
cpuinfer_threads: int,
threadpool_count: int,
weight_path: str,
@ -58,7 +59,10 @@ class GeneralMoEWrapper(BaseMoEWrapper):
num_experts_per_tok: Number of experts per token (top-k)
hidden_size: Hidden dimension size
moe_intermediate_size: MoE intermediate size
num_gpu_experts: Number of experts to run on GPU
gpu_experts_mask: Boolean mask indicating which experts are on GPU.
Shape: [num_experts], dtype: torch.bool.
mask[i] = True means expert i is on GPU.
If None, all experts are on CPU.
cpuinfer_threads: Number of CPU inference threads
threadpool_count: Number of NUMA subpools
weight_path: Path to weights (SafeTensor format)
@ -85,7 +89,7 @@ class GeneralMoEWrapper(BaseMoEWrapper):
num_experts_per_tok=num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
num_gpu_experts=num_gpu_experts,
gpu_experts_mask=gpu_experts_mask,
cpuinfer_threads=cpuinfer_threads,
threadpool_count=threadpool_count,
weight_path=weight_path,
@ -143,7 +147,7 @@ class GeneralMoEWrapper(BaseMoEWrapper):
self.num_experts_per_tok,
self.hidden_size,
self.moe_intermediate_size,
self.num_gpu_experts,
self.gpu_experts_mask.data_ptr(),
)
moe_config.layer_idx = self.layer_idx
moe_config.pool = self.cpu_infer.backend_
@ -258,7 +262,7 @@ class GeneralMoEWrapper(BaseMoEWrapper):
self.num_experts_per_tok,
self.hidden_size,
self.moe_intermediate_size,
self.num_gpu_experts,
self.gpu_experts_mask.data_ptr(),
)
moe_config.layer_idx = self.layer_idx
moe_config.pool = self.cpu_infer.backend_