feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.
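
A minimal sketch of that buffer setup for the gate_up path, assuming hypothetical
names and an [E, r, H] / [E, 2I, r] layout (the exact shapes, dtype, and helper in
lora.py may differ):

    import torch
    import torch.nn as nn

    def create_fused_lora_buffers(num_experts, hidden, inter, rank, dtype=torch.bfloat16):
        # One LoRA A/B pair per fused projection, mirroring the 3D base layout.
        lora_a = torch.empty(num_experts, rank, hidden, dtype=dtype)
        lora_b = torch.zeros(num_experts, 2 * inter, rank, dtype=dtype)
        nn.init.kaiming_uniform_(lora_a, a=5 ** 0.5)   # Kaiming init for A, zeros for B
        a_param, b_param = nn.Parameter(lora_a), nn.Parameter(lora_b)
        # Pre-assign .grad so the C++ backward can accumulate into it directly.
        a_param.grad = torch.zeros_like(a_param)
        b_param.grad = torch.zeros_like(b_param)
        return a_param, b_param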

- arch.py: detect_fused_experts() for fused-layout detection (sketched below)
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load of fused LoRA,
  get_kt_lora_params collects fused params, deduplicated wrapper lookup
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename
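
For the arch.py item above, a hypothetical sketch of what the fused-layout check
could look like; the attribute names follow the shapes listed in this message, but
the real detect_fused_experts() may differ:

    import torch.nn as nn

    def detect_fused_experts(experts) -> bool:
        # Fused experts (e.g. Qwen3MoeExperts) expose 3D weight Parameters
        # directly on the module instead of a ModuleList of per-expert Linears.
        if experts is None or isinstance(experts, nn.ModuleList):
            return False
        gate_up = getattr(experts, "gate_up_proj", None)
        down = getattr(experts, "down_proj", None)
        return (
            isinstance(gate_up, nn.Parameter) and gate_up.dim() == 3
            and isinstance(down, nn.Parameter) and down.dim() == 3
        )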

Verified: v5 loss and expert-LoRA values match the v4 baseline; v4 backward compatibility preserved.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mrhaoxx 2026-04-20 13:21:29 +08:00
parent 6d4632b8c7
commit 58d7eabb9b
6 changed files with 249 additions and 69 deletions


@@ -264,12 +264,17 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
model_container, layers = _get_model_container_and_layers(model, purpose="wrapping")
logger.info(f"Total layers={len(layers)}, is_rank_0={is_rank_0}")
from .arch import detect_fused_experts as _detect_fused
for layer_idx, layer in enumerate(layers):
moe_module = get_moe_module(layer, moe_config)
if moe_module is None:
continue
logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method})")
_layer_experts = getattr(moe_module, moe_config.experts_attr, None)
_layer_is_fused = _detect_fused(_layer_experts)
logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method}, fused={_layer_is_fused})")
# Only rank 0 loads weights and initializes KT kernel
gate_proj, up_proj, down_proj = None, None, None
@@ -312,7 +317,6 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
num_experts_per_tok=moe_config.num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_config.intermediate_size,
gpu_experts_mask=None,
num_gpu_experts=0,
cpuinfer_threads=getattr(cfg, "kt_num_threads", 1),
threadpool_count=threadpool_count,
@@ -370,7 +374,8 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
layer_idx=layer_idx,
lora_experts=lora_experts,
)
layer_wrapper._skip_lora = "SkipLoRA" in kt_method
layer_wrapper._fused_experts = _layer_is_fused
layer_wrapper._lora_rank = lora_rank
setattr(layer, moe_config.moe_layer_attr, layer_wrapper)
# Base weights have been copied into the C++ kernel's internal BufferB format.
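
For reference, a rough sketch of the fused-format extraction weights.py could
perform before handing per-expert matrices to the C++ kernel; the split order
(gate before up) and the helper name are assumptions, not the actual code:

    import torch

    def extract_fused_expert_weights(gate_up_proj: torch.Tensor, down_proj: torch.Tensor):
        # gate_up_proj: [E, 2I, H], down_proj: [E, H, I] (fused v5 layout).
        num_experts, two_i, _hidden = gate_up_proj.shape
        inter = two_i // 2
        gate_proj = gate_up_proj[:, :inter, :].contiguous()  # [E, I, H]
        up_proj = gate_up_proj[:, inter:, :].contiguous()    # [E, I, H]
        return gate_proj, up_proj, down_proj.contiguous()

Once the copies land in the kernel's BufferB format, the original 3D Parameters can
be cleared, matching the "weight clearing" item in the summary above.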