mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-29 20:29:48 +00:00
feat(sft): AMX MoE SFT backend with LoRA support
Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.
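The "old field name compatibility" item says that `_get_kt_config()` converts `kt_xxx` names to `xxx`. The body of that function is not shown on this page, so the following is only a minimal sketch of what such a prefix-stripping conversion could look like. The helper name `strip_kt_prefix`, the rule that the new spelling wins when both are present, and the example field names are all assumptions made for illustration.

```python
def strip_kt_prefix(fields: dict, prefix: str = "kt_") -> dict:
    """Hypothetical helper: map legacy 'kt_xxx' keys to 'xxx'."""
    converted = {}
    for key, value in fields.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        # Assumed policy: if both spellings are present, the value given
        # under the new (unprefixed) name takes precedence.
        if new_key in converted and key != new_key:
            continue
        converted[new_key] = value
    return converted


# Field names below are made up for the example, not taken from KTConfig.
legacy = {"kt_backend": "AMXBF16", "kt_num_threads": 32, "lora_rank": 8}
print(strip_kt_prefix(legacy))
```

Keys without the prefix pass through unchanged, so a config that already uses the new names would be unaffected by a conversion of this shape.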
parent ddb957596f
commit f36699affd
84 changed files with 51278 additions and 623 deletions
@@ -51,6 +51,15 @@ kt_kernel_ext = _kt_kernel_ext
 # Import main API
 from .experts import KTMoEWrapper
 
+def __getattr__(name):
+    if name == "AMXSFTMoEWrapper":
+        try:
+            from .sft.amx import AMXSFTMoEWrapper
+            return AMXSFTMoEWrapper
+        except (ImportError, AttributeError):
+            return None
+    raise AttributeError(f"module 'kt_kernel' has no attribute {name!r}")
+
 # Read version from package metadata (preferred) or fallback to project root
 try:
     # Try to get version from installed package metadata (works in installed environment)
@@ -82,4 +91,4 @@ except ImportError:
     except ImportError:
         __version__ = "0.4.3"
 
-__all__ = ["KTMoEWrapper", "kt_kernel_ext", "__cpu_variant__", "__version__"]
+__all__ = ["KTMoEWrapper", "AMXSFTMoEWrapper", "kt_kernel_ext", "__cpu_variant__", "__version__"]
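The `__getattr__` function added in the first hunk uses Python's module-level attribute hook (PEP 562): it runs only when normal lookup on the module fails, so `.sft.amx` is imported on first access to `AMXSFTMoEWrapper` rather than at `import kt_kernel` time. The sketch below reproduces that behaviour in isolation. It uses a throwaway module built with `types.ModuleType`; the names `demo_pkg`, `OptionalWrapper` and `demo_pkg_missing_backend` are hypothetical stand-ins, not part of kt_kernel.

```python
import importlib
import types

# Stand-in for a package's __init__ module; "demo_pkg" is a hypothetical name.
demo_pkg = types.ModuleType("demo_pkg")


def _lazy_getattr(name):
    # Called only when normal attribute lookup on the module fails (PEP 562).
    if name == "OptionalWrapper":
        try:
            # Hypothetical optional backend that is not installed here,
            # so this import raises ModuleNotFoundError (an ImportError).
            module = importlib.import_module("demo_pkg_missing_backend")
            return module.OptionalWrapper
        except (ImportError, AttributeError):
            return None
    raise AttributeError(f"module 'demo_pkg' has no attribute {name!r}")


# A function stored under "__getattr__" in the module namespace acts as the hook.
demo_pkg.__getattr__ = _lazy_getattr

print(demo_pkg.OptionalWrapper)               # None: backend unavailable
print(hasattr(demo_pkg, "OptionalWrapper"))   # True: the hook returned a value
print(hasattr(demo_pkg, "Unknown"))           # False: the hook raised AttributeError
```

One consequence of returning `None` instead of raising: `hasattr` reports the attribute as present even when the backend cannot be imported, so a caller that wants to detect availability would need an `is not None` check.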