mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-29 20:29:48 +00:00
Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD: Core C++ implementation: - sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines) - moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA - amx/moe-sft-tp.hpp: AMX-specific TP implementation - avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM - amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization - worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure - ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants) Python sft/ submodule (kt_kernel.sft): - base.py: BaseSFTMoEWrapper with buffer management (template method pattern) - amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction) - autograd.py: KTMoEFunction (torch.autograd.Function for distributed training) - layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers) - arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection) - weights.py: Expert weight extraction and checkpoint loading - lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter) - wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map - config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough) - dist_utils.py: Distributed gather/scatter, checkpoint-phase detection Design decisions: - Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights - DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework interaction fields), all logic in kt_kernel.sft - Inference isolation: importing kt_kernel does not load sft/ submodule - Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.
453 lines
16 KiB
Python
453 lines
16 KiB
Python
# Base classes for MoE CPU inference operations
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
"""
|
|
Base infrastructure for CPU-based MoE inference.
|
|
|
|
This module contains base classes and utilities shared across all backend implementations.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import torch
|
|
from typing import Dict, List, Optional, Tuple
|
|
from abc import ABC, abstractmethod
|
|
import os
|
|
import ctypes
|
|
|
|
from kt_kernel import kt_kernel_ext
|
|
|
|
|
|
class KExpertsCPUBuffer:
|
|
"""
|
|
CPU buffer management for expert computation.
|
|
|
|
Manages pinned memory buffers for efficient GPU-CPU data transfer.
|
|
"""
|
|
|
|
capture_bs: List = list()
|
|
capture_buffers: Dict = dict()
|
|
temp_bs: int = 0
|
|
temp_buffer: tuple = tuple()
|
|
buffer_depth: int = 2
|
|
|
|
@classmethod
|
|
def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok):
|
|
hidden_size = hidden_states.shape[-1]
|
|
batch_size = hidden_states.shape[0]
|
|
|
|
pin_memory = False
|
|
|
|
if batch_size in cls.capture_buffers:
|
|
return cls.capture_buffers[batch_size]
|
|
if batch_size == cls.temp_bs:
|
|
return cls.temp_buffer
|
|
|
|
input_tensor_cpu = [
|
|
torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=pin_memory, dtype=torch.bfloat16)
|
|
for _ in range(cls.buffer_depth)
|
|
]
|
|
immediate_experts_ids_cpu = [
|
|
torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=pin_memory)
|
|
for _ in range(cls.buffer_depth)
|
|
]
|
|
deferred_experts_ids_cpu = [
|
|
torch.full((batch_size, num_experts_per_tok), -1, device="cpu", dtype=torch.long, pin_memory=pin_memory)
|
|
for _ in range(cls.buffer_depth)
|
|
]
|
|
weights_cpu = [
|
|
torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=pin_memory)
|
|
for _ in range(cls.buffer_depth)
|
|
]
|
|
output_cpu = [
|
|
torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=pin_memory, dtype=torch.bfloat16)
|
|
for _ in range(cls.buffer_depth)
|
|
]
|
|
bsz_tensor_cpu = [
|
|
torch.full((1,), batch_size, device="cpu", dtype=torch.int32, pin_memory=pin_memory)
|
|
for _ in range(cls.buffer_depth)
|
|
]
|
|
output_gpu = [
|
|
torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype)
|
|
for _ in range(cls.buffer_depth)
|
|
]
|
|
|
|
cur_buffer = (
|
|
input_tensor_cpu,
|
|
immediate_experts_ids_cpu,
|
|
deferred_experts_ids_cpu,
|
|
weights_cpu,
|
|
output_cpu,
|
|
bsz_tensor_cpu,
|
|
output_gpu,
|
|
)
|
|
if batch_size in cls.capture_bs:
|
|
cls.capture_buffers[batch_size] = cur_buffer
|
|
cls.temp_bs = batch_size
|
|
cls.temp_buffer = cur_buffer
|
|
return cur_buffer
|
|
|
|
|
|
class _MoEBase:
|
|
"""
|
|
Shared base class for inference and SFT MoE wrappers.
|
|
|
|
Provides:
|
|
- CPUInfer singleton management
|
|
- Basic configuration validation
|
|
|
|
This class is shared between BaseMoEWrapper (inference) and BaseSFTMoEWrapper (SFT).
|
|
"""
|
|
|
|
_cpu_infer_instance = None
|
|
|
|
@classmethod
|
|
def _get_cpu_infer(
|
|
cls,
|
|
cpuinfer_threads: int,
|
|
threadpool_count: int,
|
|
):
|
|
"""
|
|
Get or create the CPUInfer singleton instance.
|
|
|
|
Args:
|
|
cpuinfer_threads: Total number of CPU inference threads
|
|
threadpool_count: Number of NUMA subpools (TP count)
|
|
|
|
Returns:
|
|
CPUInfer singleton instance
|
|
"""
|
|
if cls._cpu_infer_instance is None:
|
|
worker_config = kt_kernel_ext.WorkerPoolConfig()
|
|
|
|
subpool_numa_map = list(range(threadpool_count))
|
|
subpool_thread_count = [
|
|
cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
|
|
for i in range(threadpool_count)
|
|
]
|
|
|
|
worker_config.subpool_count = threadpool_count
|
|
worker_config.subpool_numa_map = subpool_numa_map
|
|
worker_config.subpool_thread_count = subpool_thread_count
|
|
cls._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config)
|
|
|
|
return cls._cpu_infer_instance
|
|
|
|
@staticmethod
|
|
def _validate_base_config(
|
|
num_experts: int,
|
|
hidden_size: int,
|
|
moe_intermediate_size: int,
|
|
num_experts_per_tok: int,
|
|
) -> None:
|
|
"""
|
|
Validate basic configuration parameters.
|
|
|
|
Raises:
|
|
ValueError: If parameters are invalid
|
|
"""
|
|
if num_experts <= 0:
|
|
raise ValueError(f"num_experts must be positive, got {num_experts}")
|
|
if hidden_size <= 0:
|
|
raise ValueError(f"hidden_size must be positive, got {hidden_size}")
|
|
if moe_intermediate_size <= 0:
|
|
raise ValueError(f"moe_intermediate_size must be positive, got {moe_intermediate_size}")
|
|
if num_experts_per_tok <= 0:
|
|
raise ValueError(f"num_experts_per_tok must be positive, got {num_experts_per_tok}")
|
|
if num_experts_per_tok > num_experts:
|
|
raise ValueError(
|
|
f"num_experts_per_tok ({num_experts_per_tok}) cannot exceed " f"num_experts ({num_experts})"
|
|
)
|
|
|
|
|
|
class BaseMoEWrapper(_MoEBase, ABC):
|
|
"""
|
|
Base class for MoE CPU inference operations.
|
|
Provides common functionality for all backend implementations.
|
|
"""
|
|
|
|
_layer_has_pending_deferred: Dict[int, bool] = {}
|
|
|
|
def __init__(
|
|
self,
|
|
layer_idx: int,
|
|
num_experts: int,
|
|
num_experts_per_tok: int,
|
|
hidden_size: int,
|
|
moe_intermediate_size: int,
|
|
num_gpu_experts: int,
|
|
cpuinfer_threads: int,
|
|
threadpool_count: int,
|
|
weight_path: str,
|
|
chunked_prefill_size: int,
|
|
cpu_save: bool = False,
|
|
max_deferred_experts_per_token: Optional[int] = None,
|
|
method: str = "AMXINT4",
|
|
):
|
|
"""
|
|
Initialize base MoE Wrapper.
|
|
|
|
Args:
|
|
layer_idx: Layer index
|
|
num_experts: Total number of experts
|
|
num_experts_per_tok: Number of experts per token (top-k)
|
|
hidden_size: Hidden dimension size
|
|
moe_intermediate_size: MoE intermediate size
|
|
num_gpu_experts: Number of experts to run on GPU
|
|
cpuinfer_threads: Number of CPU inference threads
|
|
threadpool_count: Number of NUMA subpools
|
|
weight_path: Path to weights
|
|
chunked_prefill_size: Maximum prefill chunk size
|
|
cpu_save: Whether to save weights to CPU memory
|
|
max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
|
|
method: Backend method string
|
|
"""
|
|
self.layer_idx = layer_idx
|
|
self.num_experts = num_experts
|
|
self.num_experts_per_tok = num_experts_per_tok
|
|
self.hidden_size = hidden_size
|
|
self.moe_intermediate_size = moe_intermediate_size
|
|
self.num_gpu_experts = num_gpu_experts
|
|
self.weight_path = weight_path
|
|
self.chunked_prefill_size = chunked_prefill_size
|
|
self.cpu_save = cpu_save
|
|
self.max_deferred_experts_per_token = (
|
|
int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
|
|
)
|
|
|
|
BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
|
|
self.method = method
|
|
|
|
# Initialize CPU inference engine (singleton via shared base class)
|
|
self.cpu_infer = self._get_cpu_infer(cpuinfer_threads, threadpool_count)
|
|
|
|
# Backend-specific initialization happens in subclasses
|
|
self.moe = None
|
|
|
|
@abstractmethod
|
|
def load_weights_from_tensors(
|
|
self,
|
|
gate_proj: torch.Tensor,
|
|
up_proj: torch.Tensor,
|
|
down_proj: torch.Tensor,
|
|
physical_to_logical_map_cpu: torch.Tensor,
|
|
):
|
|
"""
|
|
Load and quantize weights from BF16/FP16 tensors (online quantization).
|
|
|
|
Args:
|
|
gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
|
|
up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
|
|
down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
|
|
physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
|
|
"""
|
|
Load weights for this layer and initialize the MoE module.
|
|
|
|
Args:
|
|
physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
|
|
"""
|
|
pass
|
|
|
|
def select_deferred_experts(
|
|
self,
|
|
expert_ids: torch.Tensor,
|
|
expert_scores: torch.Tensor,
|
|
protected_k: int,
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
batch, topk = expert_ids.shape
|
|
device = expert_ids.device
|
|
|
|
protected_k = max(0, min(int(protected_k), topk))
|
|
if protected_k == 0:
|
|
deferred_ids = expert_ids.clone()
|
|
immediate_ids = torch.full_like(expert_ids, -1)
|
|
return immediate_ids, deferred_ids
|
|
|
|
topk_result = torch.topk(expert_scores, k=protected_k, dim=-1, largest=True, sorted=False)
|
|
protected_indices = topk_result.indices
|
|
protected_ids = torch.gather(expert_ids, -1, protected_indices)
|
|
|
|
protected_flag = torch.zeros((self.num_experts,), dtype=torch.int32, device=device)
|
|
protected_flag.scatter_(0, protected_ids.reshape(-1), 1)
|
|
|
|
protected_mask_flat = torch.gather(protected_flag, 0, expert_ids.reshape(-1)).ne(0)
|
|
protected_mask = protected_mask_flat.view(batch, topk)
|
|
|
|
immediate_ids = expert_ids.clone().masked_fill(~protected_mask, -1)
|
|
deferred_ids = expert_ids.clone().masked_fill(protected_mask, -1)
|
|
|
|
return immediate_ids, deferred_ids
|
|
|
|
def submit_forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
cuda_stream,
|
|
):
|
|
"""
|
|
Submit forward inference task to CPU (non-blocking).
|
|
|
|
Args:
|
|
hidden_states: Input hidden states [batch_size, hidden_size]
|
|
topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
|
|
topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
|
|
cuda_stream: CUDA stream for synchronization
|
|
"""
|
|
flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
|
batch_size = flat_hidden_states.shape[0]
|
|
|
|
(
|
|
input_tensor_cpu,
|
|
immediate_experts_ids_cpu,
|
|
deferred_experts_ids_cpu,
|
|
weights_cpu,
|
|
output_cpu,
|
|
bsz_tensor_cpu,
|
|
_output_gpu,
|
|
) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)
|
|
|
|
current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
|
|
next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth
|
|
|
|
bsz_slot_tensor = bsz_tensor_cpu[current_slot]
|
|
|
|
topk_ids_long = topk_ids.to(torch.long)
|
|
immediate_ids: torch.Tensor
|
|
deferred_ids: Optional[torch.Tensor]
|
|
if self.max_deferred_experts_per_token > 0:
|
|
protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token
|
|
|
|
immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k)
|
|
else:
|
|
immediate_ids = topk_ids_long
|
|
deferred_ids = None
|
|
|
|
input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True)
|
|
weights_cpu[current_slot].copy_(topk_weights, non_blocking=True)
|
|
immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True)
|
|
|
|
incremental = BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False)
|
|
self.cpu_infer.submit_with_cuda_stream(
|
|
cuda_stream,
|
|
self.moe.forward_task(
|
|
bsz_slot_tensor.data_ptr(),
|
|
immediate_experts_ids_cpu[current_slot].size(-1),
|
|
immediate_experts_ids_cpu[current_slot].data_ptr(),
|
|
weights_cpu[current_slot].data_ptr(),
|
|
input_tensor_cpu[current_slot].data_ptr(),
|
|
output_cpu[current_slot].data_ptr(),
|
|
incremental,
|
|
),
|
|
)
|
|
|
|
BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
|
|
if deferred_ids is not None:
|
|
deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True)
|
|
self.cpu_infer.submit_with_cuda_stream(
|
|
cuda_stream,
|
|
self.moe.forward_task(
|
|
bsz_slot_tensor.data_ptr(),
|
|
deferred_experts_ids_cpu[current_slot].size(-1),
|
|
deferred_experts_ids_cpu[current_slot].data_ptr(),
|
|
weights_cpu[current_slot].data_ptr(),
|
|
input_tensor_cpu[current_slot].data_ptr(),
|
|
output_cpu[next_slot].data_ptr(),
|
|
False,
|
|
),
|
|
)
|
|
BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True
|
|
|
|
def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor:
|
|
"""
|
|
Synchronize and retrieve forward inference results.
|
|
|
|
Args:
|
|
hidden_states: Original input hidden states (for getting buffer)
|
|
cuda_stream: CUDA stream for synchronization
|
|
|
|
Returns:
|
|
output_gpu: Output tensor on GPU
|
|
"""
|
|
flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
|
(
|
|
_input_tensor_cpu,
|
|
_immediate_experts_ids_cpu,
|
|
_deferred_experts_ids_cpu,
|
|
_weights_cpu,
|
|
output_cpu,
|
|
_bsz_tensor_cpu,
|
|
output_gpu,
|
|
) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)
|
|
|
|
current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
|
|
allow_pending = 1 if BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0
|
|
self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending)
|
|
output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True)
|
|
return output_gpu[current_slot]
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
cuda_stream,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Execute forward inference synchronously (submit + sync).
|
|
|
|
Args:
|
|
hidden_states: Input hidden states [batch_size, hidden_size]
|
|
topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
|
|
topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
|
|
cuda_stream: CUDA stream for synchronization
|
|
|
|
Returns:
|
|
Output tensor on GPU
|
|
"""
|
|
self.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream)
|
|
return self.sync_forward(hidden_states, cuda_stream)
|
|
|
|
@staticmethod
|
|
def set_capture_batch_sizes(capture_bs: List[int]):
|
|
"""
|
|
Set batch sizes to capture and cache buffers for.
|
|
|
|
This allows pre-allocation of CPU buffers for specific batch sizes,
|
|
improving performance by avoiding buffer re-allocation during inference.
|
|
|
|
Args:
|
|
capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16])
|
|
|
|
Example:
|
|
>>> BaseMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])
|
|
"""
|
|
KExpertsCPUBuffer.capture_bs = capture_bs
|
|
|
|
@staticmethod
|
|
def get_capture_batch_sizes() -> List[int]:
|
|
"""
|
|
Get currently configured capture batch sizes.
|
|
|
|
Returns:
|
|
List of batch sizes that are being captured
|
|
"""
|
|
return KExpertsCPUBuffer.capture_bs
|
|
|
|
@staticmethod
|
|
def clear_buffer_cache():
|
|
"""
|
|
Clear all cached buffers.
|
|
|
|
This frees up memory by clearing the buffer cache. Useful when you want
|
|
to reset the buffer state or free memory.
|
|
"""
|
|
KExpertsCPUBuffer.capture_buffers.clear()
|
|
KExpertsCPUBuffer.temp_bs = 0
|
|
KExpertsCPUBuffer.temp_buffer = tuple()
|
|
|