feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models built on Intel AMX:

Core C++ implementation:
- sft_moe.hpp: forward/backward passes with fused LoRA operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
  (see the usage sketch after this list)
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection
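
As a rough orientation for how these pieces fit together, the sketch below wires
them up the way the module layout suggests. The call signatures and import paths
are assumptions for illustration; only the function and class names come from the
file list above:

    # Hypothetical usage sketch -- signatures are illustrative, not the real API.
    import torch
    from transformers import AutoModelForCausalLM
    from kt_kernel.sft import KTConfig, wrap_moe_layers_with_kt_wrapper

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-235B-A22B", torch_dtype=torch.bfloat16
    )
    kt_config = KTConfig()  # opaque, DeepSpeed-style config (see config.py)
    # Replace the HF MoE layers with KTMoELayerWrapper modules (see layer.py);
    # on rank 0 this also builds the C++ wrapper holding the expert weights.
    wrap_moe_layers_with_kt_wrapper(model, kt_config)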

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds the C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only the KTransformersPlugin
  (framework-interaction fields); all logic lives in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load the sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx
  automatically (sketch below)
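
A minimal sketch of that last shim, assuming a dataclass-based KTConfig; the
config field names here are made up for illustration:

    from dataclasses import dataclass, fields

    @dataclass
    class KTConfig:  # stand-in for kt_kernel.sft.config.KTConfig
        amx_weight_path: str = ""   # hypothetical field
        cpuinfer_threads: int = 0   # hypothetical field

    def _get_kt_config(raw: dict) -> KTConfig:
        """Accept both legacy 'kt_xxx' and current 'xxx' field names."""
        valid = {f.name for f in fields(KTConfig)}
        out = {}
        for key, value in raw.items():
            if key.startswith("kt_") and key[3:] in valid:
                key = key[3:]  # e.g. kt_cpuinfer_threads -> cpuinfer_threads
            if key in valid:
                out[key] = value
        return KTConfig(**out)

    # Both spellings resolve to the same config:
    assert _get_kt_config({"kt_cpuinfer_threads": 64}) == _get_kt_config({"cpuinfer_threads": 64})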

Verified: Qwen3-235B-A22B trained on 4 GPUs with the AMXBF16 backend; loss converges normally.
Author: mrhaoxx
Date:   2026-04-08 23:11:00 +08:00
Commit: f36699affd (parent ddb957596f)

84 changed files with 51278 additions and 623 deletions

@@ -164,8 +164,12 @@ class SafeTensorLoader:
         return tensor.to(device)

     def close_all_handles(self):
-        for handle in self.file_handle_map.values():
-            handle.close()
+        """Close all file handles and clear the handle map.
+
+        Note: safetensors.safe_open doesn't have a close() method,
+        so we just clear the references and let garbage collection handle cleanup.
+        """
+        # safetensors.safe_open doesn't have close(), just clear references
         self.file_handle_map.clear()

     def load_experts(self, base_key: str, device: str = "cpu"):
@@ -202,6 +206,20 @@ class SafeTensorLoader:
         up_scales = [[] for _ in range(max_numa_id + 1)]
         gate_scales = [[] for _ in range(max_numa_id + 1)]
         down_scales = [[] for _ in range(max_numa_id + 1)]
+
+        # Check if backward weights exist
+        up_bwd_base_key = f"{base_key}.ffn_up_bwd_exps"
+        gate_bwd_base_key = f"{base_key}.ffn_gate_bwd_exps"
+        down_bwd_base_key = f"{base_key}.ffn_down_bwd_exps"
+        has_bwd = self.has_tensor(f"{gate_bwd_base_key}.{0}.numa.{0}.weight")
+        if has_bwd:
+            up_bwd_weights = [[] for _ in range(max_numa_id + 1)]
+            gate_bwd_weights = [[] for _ in range(max_numa_id + 1)]
+            down_bwd_weights = [[] for _ in range(max_numa_id + 1)]
+            up_bwd_scales = [[] for _ in range(max_numa_id + 1)]
+            gate_bwd_scales = [[] for _ in range(max_numa_id + 1)]
+            down_bwd_scales = [[] for _ in range(max_numa_id + 1)]
+
         for numa_id in range(max_numa_id + 1):
             for expert_id in range(max_experts_count + 1):
                 up_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.weight"
@@ -224,7 +242,29 @@ class SafeTensorLoader:
                 up_scales[numa_id].append(up_scale_tensor)
                 gate_scales[numa_id].append(gate_scale_tensor)
                 down_scales[numa_id].append(down_scale_tensor)
+
+                # Load backward weights if available
+                if has_bwd:
+                    gate_bwd_weights[numa_id].append(
+                        self.load_tensor(f"{gate_bwd_base_key}.{expert_id}.numa.{numa_id}.weight", device).numpy()
+                    )
+                    up_bwd_weights[numa_id].append(
+                        self.load_tensor(f"{up_bwd_base_key}.{expert_id}.numa.{numa_id}.weight", device).numpy()
+                    )
+                    down_bwd_weights[numa_id].append(
+                        self.load_tensor(f"{down_bwd_base_key}.{expert_id}.numa.{numa_id}.weight", device).numpy()
+                    )
+                    gate_bwd_scales[numa_id].append(
+                        self.load_tensor(f"{gate_bwd_base_key}.{expert_id}.numa.{numa_id}.scale", device).numpy()
+                    )
+                    up_bwd_scales[numa_id].append(
+                        self.load_tensor(f"{up_bwd_base_key}.{expert_id}.numa.{numa_id}.scale", device).numpy()
+                    )
+                    down_bwd_scales[numa_id].append(
+                        self.load_tensor(f"{down_bwd_base_key}.{expert_id}.numa.{numa_id}.scale", device).numpy()
+                    )
+
-        return {
+        result = {
             "up": up_weights,
             "gate": gate_weights,
             "down": down_weights,
@@ -232,6 +272,14 @@ class SafeTensorLoader:
             "gate_scale": gate_scales,
             "down_scale": down_scales,
         }
+        if has_bwd:
+            result["gate_bwd"] = gate_bwd_weights
+            result["up_bwd"] = up_bwd_weights
+            result["down_bwd"] = down_bwd_weights
+            result["gate_bwd_scale"] = gate_bwd_scales
+            result["up_bwd_scale"] = up_bwd_scales
+            result["down_bwd_scale"] = down_bwd_scales
+        return result

     def has_tensor(self, name: str):
         return name in self.tensor_file_map
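
For reference, a caller-side sketch of the result layout after this change; the
loader path and base key below are placeholders, only the dict keys come from the diff:

    loader = SafeTensorLoader("/path/to/amx/checkpoint")  # placeholder path
    experts = loader.load_experts("blk.0")                # placeholder base key
    up = experts["up"]               # [numa_id][expert_id] -> np.ndarray
    if "up_bwd" in experts:          # present only when *_bwd_exps tensors exist
        up_bwd = experts["up_bwd"]
        up_bwd_scale = experts["up_bwd_scale"]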
@@ -398,6 +446,111 @@ class CompressedSafeTensorLoader(SafeTensorLoader):
         }


+class BF16SafeTensorLoader(SafeTensorLoader):
+    """Loader for native BF16 expert weights (no quantization, no scales).
+
+    Supported formats:
+    - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
+    - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
+
+    The format is auto-detected during initialization.
+    """
+
+    MOE_FORMATS = {
+        "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
+        "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
+    }
+
+    def __init__(self, file_path: str):
+        super().__init__(file_path)
+        self._detected_format = None
+        self._detect_format()
+
+    def _detect_format(self):
+        """Auto-detect the MoE naming format by checking tensor keys."""
+        sample_keys = list(self.tensor_file_map.keys())[:1000]
+        for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
+            for key in sample_keys:
+                if ".experts." in key and f".{gate}.weight" in key:
+                    if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
+                        self._detected_format = fmt_name
+                        print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
+                        return
+                    elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
+                        self._detected_format = fmt_name
+                        print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
+                        return
+        self._detected_format = "deepseek"
+        print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
+
+    def _get_experts_prefix(self, base_key: str) -> str:
+        """Get the experts prefix based on detected format."""
+        path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
+        return path_tpl.format(base=base_key)
+
+    def _get_proj_names(self):
+        """Get projection names (gate, up, down) based on detected format."""
+        _, gate, up, down = self.MOE_FORMATS[self._detected_format]
+        return gate, up, down
+
+    def load_tensor(self, key: str, device: str = "cpu"):
+        if key not in self.tensor_file_map:
+            raise KeyError(f"Key {key} not found in Safetensor files")
+        file = self.tensor_file_map[key]
+        f = self.file_handle_map.get(file)
+        if f is None:
+            raise FileNotFoundError(f"File {file} not found in Safetensor files")
+        tensor = f.get_tensor(key)
+        if device == "cpu":
+            return tensor
+        return tensor.to(device)
+
+    def load_experts(self, base_key: str, device: str = "cpu"):
+        """Load BF16 expert weights (no scales needed).
+
+        Args:
+            base_key: Base key like "model.layers.{layer_index}"
+            device: Target device for tensors
+
+        Returns:
+            Dictionary with keys: gate, up, down, gate_scale (None), up_scale (None), down_scale (None)
+            gate/up/down: list of tensors [expert_id] -> tensor
+        """
+        experts_prefix = self._get_experts_prefix(base_key)
+        gate_name, up_name, down_name = self._get_proj_names()
+
+        expert_count = 0
+        while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
+            expert_count += 1
+        if expert_count == 0:
+            raise ValueError(f"No experts found for key {experts_prefix}")
+
+        gate_weights = [None] * expert_count
+        up_weights = [None] * expert_count
+        down_weights = [None] * expert_count
+
+        for exp_id in range(expert_count):
+            gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
+            up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
+            down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
+            gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
+            up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
+            down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
+
+        return {
+            "gate": gate_weights,
+            "up": up_weights,
+            "down": down_weights,
+            "gate_scale": None,
+            "up_scale": None,
+            "down_scale": None,
+        }
+
+
 class GGUFLoader:
     """
     GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)