kvcache-ai-ktransformers/kt-kernel/python/experts.py
mrhaoxx 9544a8960d
feat(sft): AMX MoE SFT backend with LoRA support (#1936)
* feat(sft): AMX MoE SFT backend with LoRA support

A complete SFT (Supervised Fine-Tuning) backend for MoE models built on Intel AMX:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (the
  framework-interaction fields); all other logic lives in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4-GPU AMXBF16 training; loss converges normally.
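
A minimal sketch of the rank-0-only pattern (the wrapper.forward signature and
the handling of routing tensors are simplified assumptions, not the actual
kt_kernel.sft API):

    import torch
    import torch.distributed as dist

    def rank0_moe_forward(x, wrapper):
        # Only rank 0 holds the C++ wrapper and expert weights; other ranks
        # ship their tokens to rank 0 and receive results scattered back.
        # (Routing tensors travel the same way; omitted for brevity.)
        rank, world = dist.get_rank(), dist.get_world_size()
        gathered = [torch.empty_like(x) for _ in range(world)] if rank == 0 else None
        dist.gather(x, gathered, dst=0)
        outs = [wrapper.forward(t) for t in gathered] if rank == 0 else None
        y = torch.empty_like(x)
        dist.scatter(y, outs, src=0)
        return y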

* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use the kt_ prefix, matching the dict keys; this
  eliminates the _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline
(1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)
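
For illustration, the unified config reduces to something like this (field set
abbreviated; the helper name is hypothetical):

    from dataclasses import dataclass

    @dataclass
    class KTConfig:
        # Every field carries the kt_ prefix, so user-dict keys map 1:1 onto
        # dataclass fields and wrapper.py needs no renaming layer.
        kt_weight_path: str = ""        # example field
        kt_method: str = "AMXBF16_SFT"  # example field
        kt_share_backward_bb: bool = False
        kt_share_cache_pool: bool = False

    def kt_config_from_dict(d: dict) -> KTConfig:
        return KTConfig(**{k: v for k, v in d.items() if k.startswith("kt_")})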

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False,
  auto-set to True by trainer_config_process when gradient checkpointing
  is enabled
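
The auto-derivation amounts to roughly the following (method signature assumed):

    def trainer_config_process(self, args):
        # With gradient checkpointing on, each layer's forward is re-run
        # during backward, so per-layer forward caches can share one pool.
        if getattr(args, "gradient_checkpointing", False):
            self.kt_share_cache_pool = True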

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument,
but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently
and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with Kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for the C++ backward.
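
The buffer creation reduces to a sketch like this (helper name and exact
shapes are illustrative; one A/B pair is shown, the real code creates pairs
for gate, up, and down):

    import math
    import torch
    from torch import nn

    def create_fused_lora_pair(E, out_dim, in_dim, r, dtype=torch.bfloat16):
        # PEFT cannot see inside the fused 3D Parameters, so KT owns these
        # buffers and hands them to the optimizer and C++ backward directly.
        lora_A = nn.Parameter(torch.empty(E, r, in_dim, dtype=dtype))
        lora_B = nn.Parameter(torch.zeros(E, out_dim, r, dtype=dtype))
        nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))  # standard LoRA init
        # Pre-assign .grad so the C++ backward can accumulate into it.
        lora_A.grad = torch.zeros_like(lora_A)
        lora_B.grad = torch.zeros_like(lora_B)
        return lora_A, lora_B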

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss and expert-LoRA values match the v4 baseline; v4 remains backward compatible.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config,
  _get_layers_prefix returns model.language_model.layers for Qwen3.5,
  _get_model_container_and_layers searches language_model attr
- weights.py: load_experts_from_checkpoint_files detects fused format
  (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.
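
The split itself is roughly as follows (a sketch that assumes gate and up
occupy concatenated halves of dim 1; the actual layout is model-specific):

    import torch

    def split_fused_experts(gate_up_proj: torch.Tensor, down_proj: torch.Tensor):
        # gate_up_proj: [E, 2I, H]; down_proj: [E, H, I]
        E, two_inter, H = gate_up_proj.shape
        inter = two_inter // 2
        gate = gate_up_proj[:, :inter, :].contiguous()  # [E, I, H]
        up = gate_up_proj[:, inter:, :].contiguous()    # [E, I, H]
        return gate, up, down_proj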

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method
calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original
method names. This caused AttributeError on Qwen3.5-35B and other models.

* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace,
  sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp,
  moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
2026-04-22 11:27:01 +08:00


# SPDX-License-Identifier: Apache-2.0
# Wrapper for MoE CPU inference operations
# This module encapsulates the CPU inference engine, weight loading, and buffer management
"""
Expert wrappers for CPU-based MoE operations (inference and SFT).
This module provides the main factory interface (KTMoEWrapper) that automatically
selects the appropriate backend implementation based on the method and mode parameters.
Usage:
# Inference mode (default)
wrapper = KTMoEWrapper(..., mode="inference", method="AMXINT4")
# SFT mode
wrapper = KTMoEWrapper(..., mode="sft", method="AMXBF16_SFT", lora_rank=16)
"""
from __future__ import annotations
import torch
from typing import List, Optional
# Import base infrastructure for inference
from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
# Import inference backend implementations
from .utils.amx import AMXMoEWrapper, NativeMoEWrapper
from .utils.llamafile import LlamafileMoEWrapper
from .utils.moe_kernel import GeneralMoEWrapper
# Valid methods for each mode
INFERENCE_METHODS = frozenset(
[
"AMXINT4",
"AMXINT8", # AMX quantization
"RAWINT4",
"FP8", # Native quantization
"BF16", # BF16 native MoE
"FP8_PERCHANNEL", # Per-channel FP8
"GPTQ_INT4", # GPTQ INT4
"LLAMAFILE", # GGUF format
"MOE_INT4",
"MOE_INT8", # General kernel
]
)
SFT_METHODS = frozenset(
[
"AMXBF16_SFT", # AMX BF16 training
"AMXINT8_SFT", # AMX INT8 training
"AMXINT4_SFT", # AMX INT4 training
"AMXINT4_1_SFT", # AMX INT4_1 training
"AMXINT4_KGroup_SFT", # AMX INT4 K-Group training
"AMXINT4_1KGroup_SFT", # AMX INT4_1 K-Group training
# SkipLoRA variants (skip all LoRA computation in backward, only compute base weight grad_input)
"AMXBF16_SFT_SkipLoRA",
"AMXINT8_SFT_SkipLoRA",
"AMXINT4_SFT_SkipLoRA",
"AMXINT4_1_SFT_SkipLoRA",
"AMXINT4_KGroup_SFT_SkipLoRA",
"AMXINT4_1KGroup_SFT_SkipLoRA",
]
)
class KTMoEWrapper:
"""
Factory interface for MoE CPU operations (inference and SFT).
This class serves as the main entry point for external code. It automatically
selects the appropriate backend implementation based on the `mode` and `method` parameters.
Supported modes:
- "inference": Optimized for low-latency inference
- "sft": Supervised fine-tuning with LoRA adapters
Usage (Inference):
# Create a mask where experts 0, 2, 5 are on GPU
gpu_mask = torch.zeros(8, dtype=torch.bool)
gpu_mask[[0, 2, 5]] = True
wrapper = KTMoEWrapper(
layer_idx=0,
num_experts=8,
num_experts_per_tok=2,
hidden_size=4096,
moe_intermediate_size=14336,
gpu_experts_mask=gpu_mask, # or None for all experts on CPU
cpuinfer_threads=32,
threadpool_count=2,
weight_path="/path/to/weights",
chunked_prefill_size=25600,
method="AMXINT4", # or "AMXINT8", "LLAMAFILE"
mode="inference", # default
)
Usage (SFT):
wrapper = KTMoEWrapper(
layer_idx=0,
num_experts=256,
num_experts_per_tok=8,
hidden_size=7168,
moe_intermediate_size=2048,
num_gpu_experts=0,
cpuinfer_threads=60,
threadpool_count=4,
weight_path="/path/to/weights",
chunked_prefill_size=25600,
method="AMXBF16_SFT", # or "AMXINT8_SFT", "AMXINT4_SFT"
mode="sft",
lora_rank=16,
lora_alpha=32.0,
)
"""
def __new__(
cls,
layer_idx: int,
num_experts: int,
num_experts_per_tok: int,
hidden_size: int,
moe_intermediate_size: int,
gpu_experts_mask: Optional[torch.Tensor],
cpuinfer_threads: int,
threadpool_count: int,
weight_path: str,
chunked_prefill_size: int,
# Inference-specific parameters
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
# Mode and method selection
method: str = "AMXINT4",
numa_nodes: Optional[List[int]] = None,
mode: str = "inference",
# SFT-specific parameters (only used when mode="sft")
num_gpu_experts: int = 0,
lora_rank: int = 16,
lora_alpha: float = 32.0,
max_cache_depth: int = 1,
# Quantization config (for K-Group SFT methods)
group_size: int = 128,
zero_point: bool = True,
):
"""
Factory method to create the appropriate backend implementation.
Args:
layer_idx: Layer index
num_experts: Total number of experts
num_experts_per_tok: Number of experts per token (top-k)
hidden_size: Hidden dimension size
moe_intermediate_size: MoE intermediate size
gpu_experts_mask: Boolean mask indicating which experts are on GPU (inference).
Shape: [num_experts], dtype: torch.bool.
mask[i] = True means expert i is on GPU.
If None, all experts are on CPU.
SFT mode uses num_gpu_experts instead.
cpuinfer_threads: Number of CPU inference threads
threadpool_count: Number of NUMA subpools (TP count)
weight_path: Path to weights
chunked_prefill_size: Maximum prefill chunk size
cpu_save: Whether to save weights to CPU memory (inference only)
max_deferred_experts_per_token: Maximum number of experts per token whose computation may be deferred (inference only)
numa_nodes: Explicit list of NUMA node IDs for subpool mapping. If None, defaults to sequential.
method: Backend method (see INFERENCE_METHODS and SFT_METHODS)
mode: Operation mode ("inference" or "sft")
lora_rank: LoRA rank (SFT only)
lora_alpha: LoRA scaling factor (SFT only)
max_cache_depth: Maximum forward cache depth (SFT only)
group_size: Quantization group size (SFT K-Group methods only)
zero_point: Use zero point quantization (SFT K-Group methods only)
Returns:
BaseMoEWrapper for inference mode, BaseSFTMoEWrapper for SFT mode
Raises:
ValueError: If mode is invalid or method doesn't match mode
"""
# Validate mode
if mode not in ("inference", "sft"):
raise ValueError(f"Unknown mode: '{mode}'. Supported modes: 'inference', 'sft'")
# Validate method matches mode
if mode == "inference":
if method not in INFERENCE_METHODS:
raise ValueError(
f"Method '{method}' not supported for inference mode. "
f"Supported methods: {sorted(INFERENCE_METHODS)}"
)
else: # mode == "sft"
if method not in SFT_METHODS:
raise ValueError(
f"Method '{method}' not supported for SFT mode. " f"Supported methods: {sorted(SFT_METHODS)}"
)
# Create appropriate backend
if mode == "inference":
return _create_inference_wrapper(
layer_idx=layer_idx,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
gpu_experts_mask=gpu_experts_mask,
cpuinfer_threads=cpuinfer_threads,
threadpool_count=threadpool_count,
weight_path=weight_path,
chunked_prefill_size=chunked_prefill_size,
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
numa_nodes=numa_nodes,
)
else: # mode == "sft"
return _create_sft_wrapper(
layer_idx=layer_idx,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
num_gpu_experts=num_gpu_experts,
cpuinfer_threads=cpuinfer_threads,
threadpool_count=threadpool_count,
weight_path=weight_path,
chunked_prefill_size=chunked_prefill_size,
method=method,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
max_cache_depth=max_cache_depth,
group_size=group_size,
zero_point=zero_point,
)
# Forward static methods to the base class
@staticmethod
def set_capture_batch_sizes(capture_bs: List[int]):
"""
Set batch sizes to capture and cache buffers for.
This allows pre-allocation of CPU buffers for specific batch sizes,
improving performance by avoiding buffer re-allocation during inference.
Args:
capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16])
"""
BaseMoEWrapper.set_capture_batch_sizes(capture_bs)
@staticmethod
def get_capture_batch_sizes() -> List[int]:
"""
Get currently configured capture batch sizes.
Returns:
List of batch sizes that are being captured
"""
return BaseMoEWrapper.get_capture_batch_sizes()
@staticmethod
def clear_buffer_cache():
"""
Clear all cached buffers.
This frees up memory by clearing the buffer cache. Useful when you want
to reset the buffer state or free memory.
"""
BaseMoEWrapper.clear_buffer_cache()
@staticmethod
def clear_sft_buffer_cache():
"""
Clear all cached SFT buffers.
This frees up memory by clearing the SFT buffer cache. Useful when you want
to reset the buffer state or free memory during SFT.
"""
from .sft.base import KExpertsSFTBuffer
KExpertsSFTBuffer.clear_cache()
# =============================================================================
# Private helper functions for creating wrapper instances
# =============================================================================
def _create_inference_wrapper(
layer_idx: int,
num_experts: int,
num_experts_per_tok: int,
hidden_size: int,
moe_intermediate_size: int,
gpu_experts_mask: Optional[torch.Tensor],
cpuinfer_threads: int,
threadpool_count: int,
weight_path: str,
chunked_prefill_size: int,
cpu_save: bool,
max_deferred_experts_per_token: Optional[int],
method: str,
numa_nodes: Optional[List[int]] = None,
) -> BaseMoEWrapper:
"""
Create an inference wrapper based on the method.
Args:
See KTMoEWrapper.__new__ for parameter descriptions.
Returns:
BaseMoEWrapper instance
"""
# Select backend based on method
if method in ["AMXINT4", "AMXINT8"]:
backend_cls = AMXMoEWrapper
elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL", "GPTQ_INT4"]:
backend_cls = NativeMoEWrapper
elif method == "LLAMAFILE":
backend_cls = LlamafileMoEWrapper
elif method in ["MOE_INT4", "MOE_INT8"]:
backend_cls = GeneralMoEWrapper
else:
# This shouldn't happen due to validation in __new__
raise NotImplementedError(f"Unsupported inference method: {method}")
# Create and return backend instance
return backend_cls(
layer_idx=layer_idx,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
gpu_experts_mask=gpu_experts_mask,
cpuinfer_threads=cpuinfer_threads,
threadpool_count=threadpool_count,
weight_path=weight_path,
chunked_prefill_size=chunked_prefill_size,
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
numa_nodes=numa_nodes,
)
def _create_sft_wrapper(
layer_idx: int,
num_experts: int,
num_experts_per_tok: int,
hidden_size: int,
moe_intermediate_size: int,
num_gpu_experts: int,
cpuinfer_threads: int,
threadpool_count: int,
weight_path: str,
chunked_prefill_size: int,
method: str,
lora_rank: int,
lora_alpha: float,
max_cache_depth: int,
group_size: int,
zero_point: bool,
):
"""
Create an SFT wrapper based on the method.
Args:
See KTMoEWrapper.__new__ for parameter descriptions.
Returns:
BaseSFTMoEWrapper instance
"""
from .sft.amx import AMXSFTMoEWrapper
# Currently only AMX SFT methods are supported
return AMXSFTMoEWrapper(
layer_idx=layer_idx,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
num_gpu_experts=num_gpu_experts,
cpuinfer_threads=cpuinfer_threads,
threadpool_count=threadpool_count,
weight_path=weight_path,
chunked_prefill_size=chunked_prefill_size,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
max_cache_depth=max_cache_depth,
method=method,
group_size=group_size,
zero_point=zero_point,
)
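

if __name__ == "__main__":
    # Illustrative smoke check, not part of the library API: constructs an SFT
    # wrapper with toy dimensions (assumes a valid weight directory is given).
    import sys

    wrapper = KTMoEWrapper(
        layer_idx=0,
        num_experts=8,
        num_experts_per_tok=2,
        hidden_size=1024,
        moe_intermediate_size=2048,
        gpu_experts_mask=None,
        cpuinfer_threads=8,
        threadpool_count=1,
        weight_path=sys.argv[1] if len(sys.argv) > 1 else "/path/to/weights",
        chunked_prefill_size=4096,
        method="AMXBF16_SFT",
        mode="sft",
        lora_rank=8,
        lora_alpha=16.0,
    )
    print(f"Created {type(wrapper).__name__}")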