Mirror of https://github.com/kvcache-ai/ktransformers.git
(kt-kernel): add numa_nodes parameter for explicit NUMA node mapping (#1891)
Add numa_nodes parameter to BaseMoEWrapper and all subclasses, allowing users to explicitly specify which NUMA node IDs to use for subpool mapping instead of always defaulting to sequential [0, 1, ..., N-1]. This enables running multiple KTransformers instances on different NUMA nodes of the same machine, e.g. --kt-threadpool-count 1 --kt-numa-nodes 1 to bind to NUMA node 1. Previously this required external numactl workarounds since subpool_numa_map was hardcoded to start from 0.
Parent: bdf4bb76c5
Commit: 3903c9afcc
5 changed files with 34 additions and 6 deletions
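
Before the per-file hunks, a minimal standalone sketch of the new mapping semantics, mirroring the logic added to BaseMoEWrapper below (the function name resolve_subpool_numa_map and the concrete values are illustrative, not part of the commit):

from typing import List, Optional

def resolve_subpool_numa_map(threadpool_count: int,
                             numa_nodes: Optional[List[int]]) -> List[int]:
    """Standalone mirror of the subpool mapping logic added in this commit."""
    if numa_nodes is not None:
        if len(numa_nodes) != threadpool_count:
            raise ValueError(
                f"numa_nodes length ({len(numa_nodes)}) must match "
                f"threadpool_count ({threadpool_count})"
            )
        return list(numa_nodes)
    # Previous behavior, kept as the default: sequential node IDs from 0.
    return list(range(threadpool_count))

# Default mapping: two subpools -> NUMA nodes [0, 1].
assert resolve_subpool_numa_map(2, None) == [0, 1]
# Explicit mapping: one subpool pinned to NUMA node 1, matching the
# commit message example (--kt-threadpool-count 1 --kt-numa-nodes 1).
assert resolve_subpool_numa_map(1, [1]) == [1]
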
File 1 of 5: KTMoEWrapper (factory)

@@ -65,6 +65,7 @@ class KTMoEWrapper:
         cpu_save: bool = False,
         max_deferred_experts_per_token: Optional[int] = None,
         method: str = "AMXINT4",
+        numa_nodes: Optional[List[int]] = None,
     ):
         """
         Factory method to create the appropriate backend implementation.
@@ -85,6 +86,7 @@ class KTMoEWrapper:
             chunked_prefill_size: Maximum prefill chunk size
             cpu_save: Whether to save weights to CPU memory
             max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
+            numa_nodes: Explicit list of NUMA node IDs for subpool mapping. If None, defaults to sequential.
             method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "BF16", "LLAMAFILE", "MOE_INT4", "MOE_INT8")

         Returns:
@@ -117,6 +119,7 @@ class KTMoEWrapper:
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )

     # Forward static methods to the base class

File 2 of 5: BaseMoEWrapper (experts_base)

@@ -164,6 +164,7 @@ class BaseMoEWrapper(ABC):
         cpu_save: bool = False,
         max_deferred_experts_per_token: Optional[int] = None,
         method: str = "AMXINT4",
+        numa_nodes: Optional[List[int]] = None,
     ):
         """
         Initialize base MoE Wrapper.
@@ -185,6 +186,8 @@ class BaseMoEWrapper(ABC):
             cpu_save: Whether to save weights to CPU memory
             max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
             method: Backend method string
+            numa_nodes: Explicit list of NUMA node IDs for subpool mapping.
+                If None, defaults to [0, 1, ..., threadpool_count-1].
         """
         self.layer_idx = layer_idx
         self.num_experts = num_experts
@@ -221,7 +224,15 @@
         if BaseMoEWrapper._cpu_infer_instance is None:
             worker_config = kt_kernel_ext.WorkerPoolConfig()

-            subpool_numa_map = list(range(threadpool_count))
+            if numa_nodes is not None:
+                if len(numa_nodes) != threadpool_count:
+                    raise ValueError(
+                        f"numa_nodes length ({len(numa_nodes)}) must match "
+                        f"threadpool_count ({threadpool_count})"
+                    )
+                subpool_numa_map = list(numa_nodes)
+            else:
+                subpool_numa_map = list(range(threadpool_count))
             subpool_thread_count = [
                 cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
                 for i in range(threadpool_count)
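
The subpool_thread_count split shown above is unchanged by this commit; as a quick standalone check of that arithmetic (the concrete values are illustrative):

# Each subpool gets floor(cpuinfer_threads / threadpool_count) threads,
# and the first (cpuinfer_threads % threadpool_count) subpools get one extra.
cpuinfer_threads, threadpool_count = 61, 2
subpool_thread_count = [
    cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
    for i in range(threadpool_count)
]
assert subpool_thread_count == [31, 30]
assert sum(subpool_thread_count) == cpuinfer_threads
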
File 3 of 5: AMXMoEWrapper / NativeMoEWrapper

@@ -1,7 +1,7 @@
 import os
 import torch
 import ctypes
-from typing import Optional
+from typing import List, Optional

 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
@@ -53,6 +53,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
         cpu_save: bool = False,
         max_deferred_experts_per_token: Optional[int] = None,
         method: str = "AMXINT4",
+        numa_nodes: Optional[List[int]] = None,
     ):
         """
         Initialize AMX MoE Wrapper.
@@ -103,6 +104,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )

         # AMX-specific: Check if we should load merged safetensor weights
@@ -288,7 +290,11 @@
         moe_config.save = True
         moe_config.load = False
         base_key = f"model.layers.{self.layer_idx}"
-        w = self.safetensor_loader.load_experts(base_key)
+        try:
+            w = self.safetensor_loader.load_experts(base_key)
+        except (ValueError, KeyError):
+            base_key = f"model.language_model.layers.{self.layer_idx}"
+            w = self.safetensor_loader.load_experts(base_key)

         self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
         self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
@@ -392,6 +398,7 @@ class NativeMoEWrapper(BaseMoEWrapper):
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )

         if NativeMoEWrapper._native_loader_instance is None:
@@ -431,7 +438,12 @@

         t0 = time.time()
         base_key = f"model.layers.{self.layer_idx}"
-        weights = self.loader.load_experts(base_key)
+        try:
+            weights = self.loader.load_experts(base_key)
+        except (ValueError, KeyError):
+            # For VL/multimodal models (e.g. Qwen3.5) with 'language_model' prefix
+            base_key = f"model.language_model.layers.{self.layer_idx}"
+            weights = self.loader.load_experts(base_key)
         t1 = time.time()

         # Keep individual tensors instead of stacking - avoid expensive memory copy

File 4 of 5: LlamafileMoEWrapper

@@ -1,5 +1,5 @@
 import torch
-from typing import Optional
+from typing import List, Optional
 import os

 # Use relative imports for package structure
@@ -133,6 +133,7 @@ class LlamafileMoEWrapper(BaseMoEWrapper):
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )

         self.weights_to_keep = None

File 5 of 5: GeneralMoEWrapper

@@ -1,7 +1,7 @@
 import os
 import torch
 import ctypes
-from typing import Optional
+from typing import List, Optional

 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
@@ -97,6 +97,7 @@ class GeneralMoEWrapper(BaseMoEWrapper):
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )

         # moe-specific: Check if we should load merged safetensor weights