mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 20:00:06 +00:00
feat(sft): AMX MoE SFT backend with LoRA support
Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD: Core C++ implementation: - sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines) - moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA - amx/moe-sft-tp.hpp: AMX-specific TP implementation - avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM - amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization - worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure - ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants) Python sft/ submodule (kt_kernel.sft): - base.py: BaseSFTMoEWrapper with buffer management (template method pattern) - amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction) - autograd.py: KTMoEFunction (torch.autograd.Function for distributed training) - layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers) - arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection) - weights.py: Expert weight extraction and checkpoint loading - lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter) - wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map - config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough) - dist_utils.py: Distributed gather/scatter, checkpoint-phase detection Design decisions: - Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights - DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework interaction fields), all logic in kt_kernel.sft - Inference isolation: importing kt_kernel does not load sft/ submodule - Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.
This commit is contained in:
parent
ddb957596f
commit
f36699affd
84 changed files with 51278 additions and 623 deletions
83
kt-kernel/python/sft/__init__.py
Normal file
83
kt-kernel/python/sft/__init__.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
# SFT (Supervised Fine-Tuning) submodule for kt-kernel
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
SFT training support for KT-Kernel MoE.
|
||||
|
||||
This submodule adds training capabilities (forward/backward, LoRA, autograd,
|
||||
distributed) on top of the inference-only kt_kernel base package.
|
||||
|
||||
Additional dependencies beyond base kt_kernel: torch.nn, torch.distributed, peft (optional).
|
||||
"""
|
||||
|
||||
from .config import KTConfig
|
||||
from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer
|
||||
from .amx import AMXSFTMoEWrapper
|
||||
from .arch import (
|
||||
MOEArchConfig, get_moe_arch_config, get_moe_module, move_non_experts_to_gpu, get_expert_device,
|
||||
KTAMXError, KTAMXNotAvailableError, KTAMXModelNotSupportedError, KTAMXConfigError,
|
||||
)
|
||||
from .autograd import KTMoEFunction
|
||||
from .layer import KTMoELayerWrapper
|
||||
from .weights import (
|
||||
extract_moe_weights,
|
||||
load_experts_from_checkpoint_files,
|
||||
load_experts_from_kt_weight_path,
|
||||
INT8ExpertWeights,
|
||||
)
|
||||
from .lora import (
|
||||
kt_adapt_peft_lora,
|
||||
get_kt_lora_params,
|
||||
update_kt_lora_pointers,
|
||||
sync_kt_lora_gradients,
|
||||
save_lora_experts_to_adapter,
|
||||
save_kt_moe_to_adapter,
|
||||
load_lora_experts_from_adapter,
|
||||
load_kt_moe_from_adapter,
|
||||
LoRAExpertMLP,
|
||||
LoRAExperts,
|
||||
)
|
||||
from .wrapper import (
|
||||
wrap_moe_layers_with_kt_wrapper,
|
||||
build_kt_device_map,
|
||||
build_kt_device_map_simplified,
|
||||
get_kt_loading_kwargs,
|
||||
load_kt_model,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"KTConfig",
|
||||
"BaseSFTMoEWrapper",
|
||||
"KExpertsSFTBuffer",
|
||||
"AMXSFTMoEWrapper",
|
||||
"MOEArchConfig",
|
||||
"get_moe_arch_config",
|
||||
"get_moe_module",
|
||||
"move_non_experts_to_gpu",
|
||||
"get_expert_device",
|
||||
"KTAMXError",
|
||||
"KTAMXNotAvailableError",
|
||||
"KTAMXModelNotSupportedError",
|
||||
"KTAMXConfigError",
|
||||
"KTMoEFunction",
|
||||
"KTMoELayerWrapper",
|
||||
"extract_moe_weights",
|
||||
"load_experts_from_checkpoint_files",
|
||||
"load_experts_from_kt_weight_path",
|
||||
"INT8ExpertWeights",
|
||||
"kt_adapt_peft_lora",
|
||||
"get_kt_lora_params",
|
||||
"update_kt_lora_pointers",
|
||||
"sync_kt_lora_gradients",
|
||||
"save_lora_experts_to_adapter",
|
||||
"save_kt_moe_to_adapter",
|
||||
"load_lora_experts_from_adapter",
|
||||
"load_kt_moe_from_adapter",
|
||||
"LoRAExpertMLP",
|
||||
"LoRAExperts",
|
||||
"wrap_moe_layers_with_kt_wrapper",
|
||||
"build_kt_device_map",
|
||||
"build_kt_device_map_simplified",
|
||||
"get_kt_loading_kwargs",
|
||||
"load_kt_model",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue