(kt-kernel): add numa_nodes parameter for explicit NUMA node mapping (#1891)

Add numa_nodes parameter to BaseMoEWrapper and all subclasses, allowing
users to explicitly specify which NUMA node IDs to use for subpool
mapping instead of always defaulting to sequential [0, 1, ..., N-1].

This enables running multiple KTransformers instances on different NUMA
nodes of the same machine, e.g. --kt-threadpool-count 1 --kt-numa-nodes 1
to bind to NUMA node 1. Previously this required external numactl
workarounds since subpool_numa_map was hardcoded to start from 0.
This commit is contained in:
ErvinXie 2026-03-31 10:27:50 +08:00 committed by GitHub
parent bdf4bb76c5
commit 3903c9afcc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 34 additions and 6 deletions

View file

@ -164,6 +164,7 @@ class BaseMoEWrapper(ABC):
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
method: str = "AMXINT4",
numa_nodes: Optional[List[int]] = None,
):
"""
Initialize base MoE Wrapper.
@ -185,6 +186,8 @@ class BaseMoEWrapper(ABC):
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
method: Backend method string
numa_nodes: Explicit list of NUMA node IDs for subpool mapping.
If None, defaults to [0, 1, ..., threadpool_count-1].
"""
self.layer_idx = layer_idx
self.num_experts = num_experts
@ -221,7 +224,15 @@ class BaseMoEWrapper(ABC):
if BaseMoEWrapper._cpu_infer_instance is None:
worker_config = kt_kernel_ext.WorkerPoolConfig()
subpool_numa_map = list(range(threadpool_count))
if numa_nodes is not None:
if len(numa_nodes) != threadpool_count:
raise ValueError(
f"numa_nodes length ({len(numa_nodes)}) must match "
f"threadpool_count ({threadpool_count})"
)
subpool_numa_map = list(numa_nodes)
else:
subpool_numa_map = list(range(threadpool_count))
subpool_thread_count = [
cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
for i in range(threadpool_count)