Mirror of https://github.com/kvcache-ai/ktransformers.git
feat(sft): AMX MoE SFT backend with LoRA support (#1936)
Some checks failed
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
* feat(sft): AMX MoE SFT backend with LoRA support

  Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD.

  Core C++ implementation:
  - sft_moe.hpp: forward/backward with LoRA fused operations (~5500 lines)
  - moe-sft-tp.hpp: tensor-parallel wrapper for multi-NUMA
  - amx/moe-sft-tp.hpp: AMX-specific TP implementation
  - avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
  - amx_kernels.hpp: AMX tile kernels for the Panel5 rank-outer optimization
  - worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
  - ext_bindings.cpp: SFT MoE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

  Python sft/ submodule (kt_kernel.sft):
  - base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
  - amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
  - autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
  - layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
  - arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
  - weights.py: expert weight extraction and checkpoint loading
  - lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
  - wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
  - config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
  - dist_utils.py: distributed gather/scatter, checkpoint-phase detection

  Design decisions:
  - Rank-0-only expert pattern: only rank 0 holds the C++ wrapper and expert weights
  - DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework interaction fields); all logic lives in kt_kernel.sft
  - Inference isolation: importing kt_kernel does not load the sft/ submodule
  - Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

  Verified: Qwen3-235B-A22B 4-GPU AMXBF16 training; loss converges normally.

* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

  - KTConfig fields all use the kt_ prefix matching the dict keys, eliminating the _OLD_TO_NEW mapping and prefix stripping in wrapper.py
  - Add the kt_share_cache_pool field, auto-enabled when gradient_checkpointing is on (via training_args.py); it flows through to C++ cache allocation
  - Remove dead checkpoint-detection code: the in_ckpt_recompute and in_ckpt_first_forward vars (assigned but never read), the fallback _is_in_checkpoint_first_forward() function, and an unused inspect import
  - Remove redundant env var fallbacks in wrapper.py for share_backward_bb and share_cache_pool (KTConfig.__post_init__ already handles env vars)
  - Simplify layer.py checkpoint logic to a single _checkpoint_hook_mode() check

  Verified: Qwen3-235B 3-step training on sap4; loss matches the baseline (1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb defaults to True, share_cache_pool auto-derived

  - kt_share_backward_bb defaults to True (always saves memory)
  - kt_share_cache_pool no longer reads from an env var; it defaults to False and is auto-set to True by trainer_config_process when gradient checkpointing is enabled

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to the KTMoEWrapper call in the SFT wrapper

  KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument, but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
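For concreteness, a minimal sketch of the kt_-prefixed, DeepSpeed-style KTConfig described in the refactor commits above. Only the two fields named in the commits appear; the env var name and the trainer_config_process signature are assumptions for illustration, not the real API.

import os
from dataclasses import dataclass

@dataclass
class KTConfig:
    # The kt_ prefix matches the user-facing dict keys directly, so no
    # _OLD_TO_NEW mapping or prefix stripping is needed in wrapper.py.
    kt_share_backward_bb: bool = True    # default True: always saves memory
    kt_share_cache_pool: bool = False    # flipped on under gradient checkpointing

    def __post_init__(self):
        # Env var override handled once, here (variable name is hypothetical):
        env = os.environ.get("KT_SHARE_BACKWARD_BB")
        if env is not None:
            self.kt_share_backward_bb = env == "1"

# Per the commits, trainer_config_process auto-derives the cache-pool flag
# (signature assumed for this sketch):
def trainer_config_process(cfg: KTConfig, gradient_checkpointing: bool) -> KTConfig:
    if gradient_checkpointing:
        cfg.kt_share_cache_pool = True
    return cfg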
* feat(sft): support transformers v5 fused expert format

  Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters (gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert nn.Linear modules. PEFT cannot attach LoRA to these, so we create KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers for the optimizer, and pre-assigned .grad for the C++ backward.

  - arch.py: detect_fused_experts() detection
  - weights.py: fused format extraction and weight clearing
  - wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
  - lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA, get_kt_lora_params collects fused params, deduplicate wrapper finding
  - layer.py: handle v5 TopKRouter tuple output, remove dead code
  - autograd.py: sync_forward_sft/submit_forward_sft API rename

  Verified: v5 loss/expert-LoRA values match the v4 baseline; v4 backward compatibility preserved.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

  - arch.py: add the Qwen3_5Moe arch match, read config from text_config, _get_layers_prefix returns model.language_model.layers for Qwen3.5, _get_model_container_and_layers searches the language_model attr
  - weights.py: load_experts_from_checkpoint_files detects the fused format (gate_up_proj in weight_map) and splits it into gate/up/down
  - wrapper.py: hidden_size falls back to text_config

  Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(sft): align Python API with C++ backend after v5 refactor

  - wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by the C++ signature)
  - layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
  - autograd.py: rename sync_forward_sft to sync_forward

  The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method calls, but the C++ backend (AMXSFTMoEWrapper) still exposes the original method names. This caused an AttributeError on Qwen3.5-35B and other models.

* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

  - Revert worker_pool.cpp/.h to main (remove the RDTSC timer, Chrome Trace, the sft_timer namespace, the ITT API, and the extended do_work_stealing_job API)
  - Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp, moe-sft-tp.hpp, avx_kernels.hpp)
  - Restore pin_memory=True in KExpertsCPUBuffer (inference path)
  - Restore the fused tensor transpose logic in convert_cpu_weights.py (main layout)

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and the cpptrace dep

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

  Remove files not needed in the merge:
  - docs/SFT+KTWrapper/ (6 Chinese design docs)
  - docs/sft_moe_amx/ (21 dev/debug docs)
  - 12 debug/test example scripts
  - 6 SFT-specific bench scripts and a report

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
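As a rough sketch of the fused-expert LoRA buffers described in the transformers-v5 commit above: the sizes below are invented for illustration, and the real logic lives in lora.py (_create_fused_expert_lora_buffers); this only mirrors the three stated ingredients (kaiming init, nn.Parameter wrappers, pre-assigned .grad).

import math
import torch

E, H, I, r = 8, 1024, 2816, 16   # experts, hidden, intermediate, LoRA rank (example sizes)

# PEFT cannot attach LoRA to a 3D fused Parameter (gate_up_proj [E, 2I, H]),
# so KT manages per-expert A/B buffers itself.
lora_A = torch.empty(E, r, H)
for e in range(E):
    torch.nn.init.kaiming_uniform_(lora_A[e], a=math.sqrt(5))  # kaiming init per expert
lora_B = torch.zeros(E, 2 * I, r)   # B starts at zero, so LoRA is initially a no-op

# nn.Parameter wrappers so the optimizer tracks them, plus pre-assigned
# .grad buffers the C++ backward can accumulate into.
lora_A = torch.nn.Parameter(lora_A.to(torch.bfloat16))
lora_B = torch.nn.Parameter(lora_B.to(torch.bfloat16))
lora_A.grad = torch.zeros_like(lora_A)
lora_B.grad = torch.zeros_like(lora_B)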
parent 22e9915ec9
commit 9544a8960d
41 changed files with 22866 additions and 937 deletions
@@ -164,8 +164,12 @@ class SafeTensorLoader:
         return tensor.to(device)
 
     def close_all_handles(self):
-        for handle in self.file_handle_map.values():
-            handle.close()
+        """Close all file handles and clear the handle map.
+
+        Note: safetensors.safe_open doesn't have a close() method,
+        so we just clear the references and let garbage collection handle cleanup.
+        """
+        # safetensors.safe_open doesn't have close(), just clear references
+        self.file_handle_map.clear()
 
     def load_experts(self, base_key: str, device: str = "cpu"):
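For context on why the loop was dropped: handles returned by safetensors.safe_open expose keys()/get_tensor() but no close(), so clearing the dict is the whole cleanup. A minimal sketch (the file name is illustrative):

from safetensors import safe_open

def open_handles(paths):
    # safe_open returns a handle with keys()/get_tensor() but no close()
    return {p: safe_open(p, framework="pt", device="cpu") for p in paths}

handles = open_handles(["model-00001-of-00009.safetensors"])
# ... read tensors via handles[name].get_tensor(key) ...
handles.clear()  # drop the references; garbage collection unmaps the files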
@@ -202,6 +206,20 @@ class SafeTensorLoader:
         up_scales = [[] for _ in range(max_numa_id + 1)]
         gate_scales = [[] for _ in range(max_numa_id + 1)]
         down_scales = [[] for _ in range(max_numa_id + 1)]
+        # Check if backward weights exist
+        up_bwd_base_key = f"{base_key}.ffn_up_bwd_exps"
+        gate_bwd_base_key = f"{base_key}.ffn_gate_bwd_exps"
+        down_bwd_base_key = f"{base_key}.ffn_down_bwd_exps"
+        has_bwd = self.has_tensor(f"{gate_bwd_base_key}.{0}.numa.{0}.weight")
+
+        if has_bwd:
+            up_bwd_weights = [[] for _ in range(max_numa_id + 1)]
+            gate_bwd_weights = [[] for _ in range(max_numa_id + 1)]
+            down_bwd_weights = [[] for _ in range(max_numa_id + 1)]
+            up_bwd_scales = [[] for _ in range(max_numa_id + 1)]
+            gate_bwd_scales = [[] for _ in range(max_numa_id + 1)]
+            down_bwd_scales = [[] for _ in range(max_numa_id + 1)]
+
         for numa_id in range(max_numa_id + 1):
             for expert_id in range(max_experts_count + 1):
                 up_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.weight"
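For orientation, the backward-weight keys follow a per-expert, per-NUMA layout, and the hunk above only probes gate expert 0 on NUMA node 0, assuming the remaining experts follow suit. A sketch of the probe (the layer prefix is an example, and loader stands in for a constructed SafeTensorLoader):

def probe_backward_weights(loader, base_key="model.layers.3"):
    # Backward copies live under ffn_{gate,up,down}_bwd_exps, keyed as
    #   {base}.ffn_gate_bwd_exps.{expert_id}.numa.{numa_id}.weight
    gate_bwd_base_key = f"{base_key}.ffn_gate_bwd_exps"
    return loader.has_tensor(f"{gate_bwd_base_key}.0.numa.0.weight")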
@@ -224,7 +242,29 @@ class SafeTensorLoader:
                 up_scales[numa_id].append(up_scale_tensor)
                 gate_scales[numa_id].append(gate_scale_tensor)
                 down_scales[numa_id].append(down_scale_tensor)
-        return {
+
+                # Load backward weights if available
+                if has_bwd:
+                    gate_bwd_weights[numa_id].append(
+                        self.load_tensor(f"{gate_bwd_base_key}.{expert_id}.numa.{numa_id}.weight", device).numpy()
+                    )
+                    up_bwd_weights[numa_id].append(
+                        self.load_tensor(f"{up_bwd_base_key}.{expert_id}.numa.{numa_id}.weight", device).numpy()
+                    )
+                    down_bwd_weights[numa_id].append(
+                        self.load_tensor(f"{down_bwd_base_key}.{expert_id}.numa.{numa_id}.weight", device).numpy()
+                    )
+                    gate_bwd_scales[numa_id].append(
+                        self.load_tensor(f"{gate_bwd_base_key}.{expert_id}.numa.{numa_id}.scale", device).numpy()
+                    )
+                    up_bwd_scales[numa_id].append(
+                        self.load_tensor(f"{up_bwd_base_key}.{expert_id}.numa.{numa_id}.scale", device).numpy()
+                    )
+                    down_bwd_scales[numa_id].append(
+                        self.load_tensor(f"{down_bwd_base_key}.{expert_id}.numa.{numa_id}.scale", device).numpy()
+                    )
+
+        result = {
             "up": up_weights,
             "gate": gate_weights,
             "down": down_weights,
@@ -232,6 +272,14 @@ class SafeTensorLoader:
             "gate_scale": gate_scales,
             "down_scale": down_scales,
         }
+        if has_bwd:
+            result["gate_bwd"] = gate_bwd_weights
+            result["up_bwd"] = up_bwd_weights
+            result["down_bwd"] = down_bwd_weights
+            result["gate_bwd_scale"] = gate_bwd_scales
+            result["up_bwd_scale"] = up_bwd_scales
+            result["down_bwd_scale"] = down_bwd_scales
+        return result
 
     def has_tensor(self, name: str):
        return name in self.tensor_file_map
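Downstream callers can branch on the optional bwd keys; a usage sketch (layer prefix and indices are illustrative, and loader is an already-constructed instance):

def inspect_experts(loader, base_key="model.layers.3"):
    # Nesting is weights[numa_id][expert_id] -> numpy array, matching the
    # loops above (bwd entries are built via .numpy()).
    result = loader.load_experts(base_key)
    gate_w = result["gate"][0][5]            # NUMA node 0, expert 5
    if "gate_bwd" in result:                 # present only when has_bwd was True
        gate_bwd_w = result["gate_bwd"][0][5]
        gate_bwd_s = result["gate_bwd_scale"][0][5]
    return result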
@@ -675,6 +723,111 @@ class CompressedSafeTensorLoader(SafeTensorLoader):
         }
 
 
+class BF16SafeTensorLoader(SafeTensorLoader):
+    """Loader for native BF16 expert weights (no quantization, no scales).
+
+    Supported formats:
+    - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
+    - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
+
+    The format is auto-detected during initialization.
+    """
+
+    MOE_FORMATS = {
+        "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
+        "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
+    }
+
+    def __init__(self, file_path: str):
+        super().__init__(file_path)
+        self._detected_format = None
+        self._detect_format()
+
+    def _detect_format(self):
+        """Auto-detect the MoE naming format by checking tensor keys."""
+        sample_keys = list(self.tensor_file_map.keys())[:1000]
+
+        for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
+            for key in sample_keys:
+                if ".experts." in key and f".{gate}.weight" in key:
+                    if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
+                        self._detected_format = fmt_name
+                        print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
+                        return
+                    elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
+                        self._detected_format = fmt_name
+                        print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
+                        return
+
+        self._detected_format = "deepseek"
+        print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
+
+    def _get_experts_prefix(self, base_key: str) -> str:
+        """Get the experts prefix based on detected format."""
+        path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
+        return path_tpl.format(base=base_key)
+
+    def _get_proj_names(self):
+        """Get projection names (gate, up, down) based on detected format."""
+        _, gate, up, down = self.MOE_FORMATS[self._detected_format]
+        return gate, up, down
+
+    def load_tensor(self, key: str, device: str = "cpu"):
+        if key not in self.tensor_file_map:
+            raise KeyError(f"Key {key} not found in Safetensor files")
+        file = self.tensor_file_map[key]
+        f = self.file_handle_map.get(file)
+        if f is None:
+            raise FileNotFoundError(f"File {file} not found in Safetensor files")
+        tensor = f.get_tensor(key)
+        if device == "cpu":
+            return tensor
+        return tensor.to(device)
+
+    def load_experts(self, base_key: str, device: str = "cpu"):
+        """Load BF16 expert weights (no scales needed).
+
+        Args:
+            base_key: Base key like "model.layers.{layer_index}"
+            device: Target device for tensors
+
+        Returns:
+            Dictionary with keys: gate, up, down, gate_scale (None), up_scale (None), down_scale (None)
+            gate/up/down: list of tensors [expert_id] -> tensor
+        """
+        experts_prefix = self._get_experts_prefix(base_key)
+        gate_name, up_name, down_name = self._get_proj_names()
+
+        expert_count = 0
+        while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
+            expert_count += 1
+
+        if expert_count == 0:
+            raise ValueError(f"No experts found for key {experts_prefix}")
+
+        gate_weights = [None] * expert_count
+        up_weights = [None] * expert_count
+        down_weights = [None] * expert_count
+
+        for exp_id in range(expert_count):
+            gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
+            up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
+            down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
+
+            gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
+            up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
+            down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
+
+        return {
+            "gate": gate_weights,
+            "up": up_weights,
+            "down": down_weights,
+            "gate_scale": None,
+            "up_scale": None,
+            "down_scale": None,
+        }
+
+
 class GGUFLoader:
     """
     GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)
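A quick usage sketch for the new loader; the checkpoint path and layer index are illustrative:

def load_bf16_layer(path, layer_idx=10):
    loader = BF16SafeTensorLoader(path)   # auto-detects DeepSeek vs Mixtral naming
    experts = loader.load_experts(f"model.layers.{layer_idx}")
    assert experts["gate_scale"] is None  # native BF16: no quantization scales
    # Unlike the NUMA-sharded loader, lists here are flat [expert_id] -> tensor.
    return experts["gate"], experts["up"], experts["down"]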