mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 20:00:06 +00:00
Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD.

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MoE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds the C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework interaction fields); all logic lives in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load the sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4-GPU AMXBF16 training; loss converges normally.
83 lines
2.3 KiB
Python
83 lines
2.3 KiB
Python
# SFT (Supervised Fine-Tuning) submodule for kt-kernel
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
"""
|
|
SFT training support for KT-Kernel MoE.
|
|
|
|
This submodule adds training capabilities (forward/backward, LoRA, autograd,
|
|
distributed) on top of the inference-only kt_kernel base package.
|
|
|
|
Additional dependencies beyond base kt_kernel: torch.nn, torch.distributed, peft (optional).
|
|
"""
|
|
|
|
from .config import KTConfig
|
|
from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer
|
|
from .amx import AMXSFTMoEWrapper
|
|
from .arch import (
|
|
MOEArchConfig, get_moe_arch_config, get_moe_module, move_non_experts_to_gpu, get_expert_device,
|
|
KTAMXError, KTAMXNotAvailableError, KTAMXModelNotSupportedError, KTAMXConfigError,
|
|
)
|
|
from .autograd import KTMoEFunction
|
|
from .layer import KTMoELayerWrapper
|
|
from .weights import (
|
|
extract_moe_weights,
|
|
load_experts_from_checkpoint_files,
|
|
load_experts_from_kt_weight_path,
|
|
INT8ExpertWeights,
|
|
)
|
|
from .lora import (
|
|
kt_adapt_peft_lora,
|
|
get_kt_lora_params,
|
|
update_kt_lora_pointers,
|
|
sync_kt_lora_gradients,
|
|
save_lora_experts_to_adapter,
|
|
save_kt_moe_to_adapter,
|
|
load_lora_experts_from_adapter,
|
|
load_kt_moe_from_adapter,
|
|
LoRAExpertMLP,
|
|
LoRAExperts,
|
|
)
|
|
from .wrapper import (
|
|
wrap_moe_layers_with_kt_wrapper,
|
|
build_kt_device_map,
|
|
build_kt_device_map_simplified,
|
|
get_kt_loading_kwargs,
|
|
load_kt_model,
|
|
)
|
|
|
|
# Public API of the kt_kernel.sft submodule.
# Every name re-exported by the relative imports at the top of this module is
# listed here exactly once, grouped by the submodule it comes from, so that
# `from kt_kernel.sft import *` exposes precisely these names. Keep this list
# in sync whenever an import is added or removed.
__all__ = [
    # .config
    "KTConfig",
    # .base
    "BaseSFTMoEWrapper",
    "KExpertsSFTBuffer",
    # .amx
    "AMXSFTMoEWrapper",
    # .arch
    "MOEArchConfig",
    "get_moe_arch_config",
    "get_moe_module",
    "move_non_experts_to_gpu",
    "get_expert_device",
    "KTAMXError",
    "KTAMXNotAvailableError",
    "KTAMXModelNotSupportedError",
    "KTAMXConfigError",
    # .autograd
    "KTMoEFunction",
    # .layer
    "KTMoELayerWrapper",
    # .weights
    "extract_moe_weights",
    "load_experts_from_checkpoint_files",
    "load_experts_from_kt_weight_path",
    "INT8ExpertWeights",
    # .lora
    "kt_adapt_peft_lora",
    "get_kt_lora_params",
    "update_kt_lora_pointers",
    "sync_kt_lora_gradients",
    "save_lora_experts_to_adapter",
    "save_kt_moe_to_adapter",
    "load_lora_experts_from_adapter",
    "load_kt_moe_from_adapter",
    "LoRAExpertMLP",
    "LoRAExperts",
    # .wrapper
    "wrap_moe_layers_with_kt_wrapper",
    "build_kt_device_map",
    "build_kt_device_map_simplified",
    "get_kt_loading_kwargs",
    "load_kt_model",
]
|