feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.
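
As a rough sketch of the mechanism (illustrative names and shapes only, not
the actual _create_fused_expert_lora_buffers API):

    import math
    import torch
    import torch.nn as nn

    def make_expert_lora_pair(num_experts, rank, in_features, out_features,
                              dtype=torch.bfloat16):
        # A is kaiming-initialized, B starts at zero, so the initial
        # LoRA delta (B @ A) is a no-op before training.
        lora_a = torch.empty(num_experts, rank, in_features, dtype=dtype)
        nn.init.kaiming_uniform_(lora_a, a=math.sqrt(5))
        lora_b = torch.zeros(num_experts, out_features, rank, dtype=dtype)
        a, b = nn.Parameter(lora_a), nn.Parameter(lora_b)
        # Pre-assign .grad so the C++ backward can accumulate into stable
        # buffers instead of relying on autograd to allocate them.
        a.grad = torch.zeros_like(a)
        b.grad = torch.zeros_like(b)
        return a, b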

- arch.py: add detect_fused_experts() helper
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss and expert-LoRA values match the v4 baseline; v4 backward
compatibility is preserved.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@@ -40,8 +40,28 @@ def extract_moe_weights(
    Returns (gate_proj, up_proj, down_proj) with shape
    [expert_num, out_features, in_features].

    Supports two formats:
    - ModuleList of Linear experts (transformers v4 style)
    - Fused Parameters (transformers v5 style): single module with
      ``gate_up_proj`` [E, 2*I, H] and ``down_proj`` [E, H, I] tensors.
    """
    from .arch import detect_fused_experts

    experts = getattr(moe_module, moe_config.experts_attr)
    # Fused format (transformers v5): a single nn.Module with
    # gate_up_proj/down_proj tensors instead of per-expert Linear modules.
    if detect_fused_experts(experts):
        gate_up = getattr(experts, "gate_up_proj").data
        down_fused = getattr(experts, "down_proj").data
        # gate_up_proj is [E, 2*I, H]; split into gate [E, I, H] and up [E, I, H]
        intermediate = gate_up.shape[1] // 2
        gate_proj = gate_up[:, :intermediate, :].contiguous()
        up_proj = gate_up[:, intermediate:, :].contiguous()
        # down_proj is already [E, H, I]
        down_proj = down_fused.contiguous()
        return gate_proj, up_proj, down_proj

    gate_name, up_name, down_name = moe_config.weight_names
    gather_params: list[torch.nn.Parameter] = []
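
detect_fused_experts lives in arch.py and is not shown in this diff; a minimal
sketch of what such a check might look like, assuming it simply keys off the
presence of 3D fused Parameters (illustrative, not the actual implementation):

    import torch

    def detect_fused_experts(experts) -> bool:
        # Fused v5 experts expose 3D gate_up_proj/down_proj Parameters
        # directly on the module instead of per-expert nn.Linear children.
        gate_up = getattr(experts, "gate_up_proj", None)
        down = getattr(experts, "down_proj", None)
        return (
            isinstance(gate_up, torch.nn.Parameter) and gate_up.dim() == 3
            and isinstance(down, torch.nn.Parameter) and down.dim() == 3
        )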
@@ -92,10 +112,27 @@ def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchConfig):
"""
Clear original expert weights to free memory after KT weights are loaded.
"""
from .arch import detect_fused_experts
experts = getattr(moe_module, moe_config.experts_attr, None)
if experts is None:
return
# Fused format: replace gate_up_proj/down_proj tensors with zero-storage placeholders
if detect_fused_experts(experts):
for name in ("gate_up_proj", "down_proj"):
param = getattr(experts, name, None)
if not isinstance(param, torch.nn.Parameter):
continue
original_dtype = param.dtype
tiny_storage = torch.UntypedStorage(1, device="cpu")
fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_(
tiny_storage, storage_offset=0, size=param.shape,
stride=[0] * len(param.shape),
)
experts._parameters[name] = nn.Parameter(fake_tensor, requires_grad=False)
return
def _iter_weight_params():
for expert in experts:
for weight_name in moe_config.weight_names:
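
For reference, the zero-storage placeholder above keeps the Parameter's shape
and dtype (so later shape and state_dict checks still pass) while backing it
with a throwaway storage. A standalone illustration of the same trick
(variable names here are illustrative):

    import torch

    big = torch.nn.Parameter(torch.randn(64, 1024, 512))   # ~128 MiB of fp32
    tiny = torch.UntypedStorage(1, device="cpu")            # throwaway backing store
    fake = torch.tensor([], dtype=big.dtype, device="cpu").set_(
        tiny, storage_offset=0, size=big.shape, stride=[0] * big.dim(),
    )
    placeholder = torch.nn.Parameter(fake, requires_grad=False)

    assert placeholder.shape == big.shape                   # shape metadata preserved
    assert placeholder.untyped_storage().nbytes() < 1024    # real data is gone
    # In _clear_original_expert_weights the old Parameter is dropped from
    # experts._parameters, so its ~128 MiB storage can be garbage-collected.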