(kt-kernel): add numa_nodes parameter for explicit NUMA node mapping (#1891)

Add numa_nodes parameter to BaseMoEWrapper and all subclasses, allowing
users to explicitly specify which NUMA node IDs to use for subpool
mapping instead of always defaulting to sequential [0, 1, ..., N-1].

This enables running multiple KTransformers instances on different NUMA
nodes of the same machine, e.g. --kt-threadpool-count 1 --kt-numa-nodes 1
to bind to NUMA node 1. Previously this required external numactl
workarounds since subpool_numa_map was hardcoded to start from 0.
ErvinXie 2026-03-31 10:27:50 +08:00 committed by GitHub
parent bdf4bb76c5
commit 3903c9afcc
5 changed files with 34 additions and 6 deletions
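
For reference, a minimal sketch of the mapping behaviour described above. resolve_subpool_numa_map is a hypothetical helper, not code from this commit (the BaseMoEWrapper change itself is not among the hunks shown below); it only illustrates how an explicit numa_nodes list replaces the old sequential default.

from typing import List, Optional


def resolve_subpool_numa_map(threadpool_count: int,
                             numa_nodes: Optional[List[int]] = None) -> List[int]:
    """Hypothetical helper: pick the NUMA node ID for each subpool.

    numa_nodes=None keeps the old behaviour of [0, 1, ..., N-1];
    an explicit list lets an instance bind to arbitrary nodes.
    """
    if numa_nodes is None:
        return list(range(threadpool_count))  # previous hardcoded default
    if len(numa_nodes) != threadpool_count:
        raise ValueError("numa_nodes must provide one node ID per subpool")
    return list(numa_nodes)


# e.g. --kt-threadpool-count 1 --kt-numa-nodes 1  ->  subpool 0 on NUMA node 1
assert resolve_subpool_numa_map(1, [1]) == [1]
assert resolve_subpool_numa_map(2) == [0, 1]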

View file

@@ -1,7 +1,7 @@
 import os
 import torch
 import ctypes
-from typing import Optional
+from typing import List, Optional

 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
@@ -53,6 +53,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
         cpu_save: bool = False,
         max_deferred_experts_per_token: Optional[int] = None,
         method: str = "AMXINT4",
+        numa_nodes: Optional[List[int]] = None,
     ):
         """
         Initialize AMX MoE Wrapper.
@@ -103,6 +104,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )

         # AMX-specific: Check if we should load merged safetensor weights
@@ -288,7 +290,11 @@ class AMXMoEWrapper(BaseMoEWrapper):
         moe_config.save = True
         moe_config.load = False
         base_key = f"model.layers.{self.layer_idx}"
-        w = self.safetensor_loader.load_experts(base_key)
+        try:
+            w = self.safetensor_loader.load_experts(base_key)
+        except (ValueError, KeyError):
+            base_key = f"model.language_model.layers.{self.layer_idx}"
+            w = self.safetensor_loader.load_experts(base_key)

         self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
         self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
@@ -392,6 +398,7 @@ class NativeMoEWrapper(BaseMoEWrapper):
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )

         if NativeMoEWrapper._native_loader_instance is None:
@@ -431,7 +438,12 @@ class NativeMoEWrapper(BaseMoEWrapper):

         t0 = time.time()
         base_key = f"model.layers.{self.layer_idx}"
-        weights = self.loader.load_experts(base_key)
+        try:
+            weights = self.loader.load_experts(base_key)
+        except (ValueError, KeyError):
+            # For VL/multimodal models (e.g. Qwen3.5) with 'language_model' prefix
+            base_key = f"model.language_model.layers.{self.layer_idx}"
+            weights = self.loader.load_experts(base_key)
         t1 = time.time()

         # Keep individual tensors instead of stacking - avoid expensive memory copy
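
The same key fallback is added to both AMXMoEWrapper and NativeMoEWrapper above. A minimal sketch of the pattern, assuming a loader object with the load_experts method seen in the diff; load_layer_experts itself is a hypothetical stand-in, not part of this commit.

# Hypothetical consolidation of the fallback above: try the plain decoder
# prefix first, then the 'language_model' prefix used by VL/multimodal models.
def load_layer_experts(loader, layer_idx: int):
    for base_key in (f"model.layers.{layer_idx}",
                     f"model.language_model.layers.{layer_idx}"):
        try:
            return loader.load_experts(base_key)
        except (ValueError, KeyError):
            continue
    raise KeyError(f"no expert weights found for layer {layer_idx}")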

View file

@@ -1,5 +1,5 @@
 import torch
-from typing import Optional
+from typing import List, Optional
 import os

 # Use relative imports for package structure
@@ -133,6 +133,7 @@ class LlamafileMoEWrapper(BaseMoEWrapper):
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )

         self.weights_to_keep = None

View file

@@ -1,7 +1,7 @@
 import os
 import torch
 import ctypes
-from typing import Optional
+from typing import List, Optional

 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
@@ -97,6 +97,7 @@ class GeneralMoEWrapper(BaseMoEWrapper):
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )

         # moe-specific: Check if we should load merged safetensor weights