feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework
  interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.
This commit is contained in:
mrhaoxx 2026-04-08 23:11:00 +08:00
parent ddb957596f
commit f36699affd
84 changed files with 51278 additions and 623 deletions

View file

@ -51,6 +51,15 @@ kt_kernel_ext = _kt_kernel_ext
# Import main API
from .experts import KTMoEWrapper
def __getattr__(name):
if name == "AMXSFTMoEWrapper":
try:
from .sft.amx import AMXSFTMoEWrapper
return AMXSFTMoEWrapper
except (ImportError, AttributeError):
return None
raise AttributeError(f"module 'kt_kernel' has no attribute {name!r}")
# Read version from package metadata (preferred) or fallback to project root
try:
# Try to get version from installed package metadata (works in installed environment)
@ -82,4 +91,4 @@ except ImportError:
except ImportError:
__version__ = "0.4.3"
__all__ = ["KTMoEWrapper", "kt_kernel_ext", "__cpu_variant__", "__version__"]
__all__ = ["KTMoEWrapper", "AMXSFTMoEWrapper", "kt_kernel_ext", "__cpu_variant__", "__version__"]