Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2026-04-30 21:00:07 +00:00)
feat(sft): support transformers v5 fused expert format
Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E, 2I, H], down_proj [E, H, I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers for
the optimizer, and pre-assigned .grad for the C++ backward pass.

- arch.py: add detect_fused_experts() detection
- weights.py: fused-format extraction and weight clearing
- wrapper.py: detect fused format at wrap time; store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers; save/load fused LoRA;
  get_kt_lora_params collects fused params; deduplicate wrapper lookup
- layer.py: handle v5 TopKRouter tuple output; remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss and expert-LoRA values match the v4 baseline; v4
backward compatibility is preserved.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
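A minimal sketch of the KT-managed LoRA buffer scheme the message describes, assuming made-up sizes and the standard LoRA convention (kaiming-initialized A, zero-initialized B); the function name mirrors lora.py's _create_fused_expert_lora_buffers, but everything else here is a hypothetical stand-in, not the actual implementation:

    import math
    import torch
    import torch.nn as nn

    def create_fused_expert_lora_buffers(E, H, I, r, dtype=torch.float32):
        # LoRA A gets kaiming init and LoRA B starts at zero, so the
        # adapter is a no-op before any training step (standard LoRA).
        lora_A = torch.empty(E, r, H, dtype=dtype)
        nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))
        lora_B = torch.zeros(E, I, r, dtype=dtype)

        # nn.Parameter wrappers so a stock optimizer can track the buffers.
        params = {"lora_A": nn.Parameter(lora_A), "lora_B": nn.Parameter(lora_B)}

        # Pre-assign .grad so an external (C++) backward can accumulate
        # gradients in place instead of relying on autograd to allocate them.
        for p in params.values():
            p.grad = torch.zeros_like(p)
        return params

    # E experts, hidden H, intermediate I, rank r -- all illustrative sizes.
    buffers = create_fused_expert_lora_buffers(E=8, H=1024, I=2816, r=16)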
This commit is contained in:
parent 6d4632b8c7
commit 58d7eabb9b
6 changed files with 249 additions and 69 deletions
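Both hunks below import detect_fused_experts from arch.py, which is outside this excerpt. A plausible sketch of that check, purely an assumption about its shape: a fused v5 experts module exposes 3D gate_up_proj/down_proj Parameters instead of being an iterable of per-expert Linear modules.

    import torch
    import torch.nn as nn

    def detect_fused_experts(experts) -> bool:
        # Assumed heuristic (arch.py is not shown in this diff): fused v5
        # modules carry 3D gate_up_proj/down_proj Parameters directly.
        gate_up = getattr(experts, "gate_up_proj", None)
        down = getattr(experts, "down_proj", None)
        return (
            isinstance(gate_up, nn.Parameter) and gate_up.dim() == 3
            and isinstance(down, nn.Parameter) and down.dim() == 3
        )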
@@ -40,8 +40,28 @@ def extract_moe_weights(
    Returns (gate_proj, up_proj, down_proj) with shape
    [expert_num, out_features, in_features].

    Supports two formats:
    - ModuleList of Linear experts (transformers v4 style)
    - Fused Parameters (transformers v5 style): single module with
      ``gate_up_proj`` [E, 2*I, H] and ``down_proj`` [E, H, I] tensors.
    """
    from .arch import detect_fused_experts

    experts = getattr(moe_module, moe_config.experts_attr)

    # Fused format (transformers v5): a single nn.Module with gate_up_proj/down_proj tensors
    if detect_fused_experts(experts):
        gate_up = getattr(experts, "gate_up_proj").data
        down_fused = getattr(experts, "down_proj").data
        # gate_up_proj is [E, 2*I, H]; split into gate [E, I, H] and up [E, I, H]
        intermediate = gate_up.shape[1] // 2
        gate_proj = gate_up[:, :intermediate, :].contiguous()
        up_proj = gate_up[:, intermediate:, :].contiguous()
        # down_proj is already [E, H, I]
        down_proj = down_fused.contiguous()
        return gate_proj, up_proj, down_proj

    gate_name, up_name, down_name = moe_config.weight_names

    gather_params: list[torch.nn.Parameter] = []
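The fused split above can be sanity-checked in isolation; a standalone snippet with made-up sizes (not part of the diff):

    import torch

    E, I, H = 4, 32, 64                 # made-up expert/intermediate/hidden sizes
    gate_up = torch.randn(E, 2 * I, H)  # fused v5 layout [E, 2*I, H]

    intermediate = gate_up.shape[1] // 2
    gate_proj = gate_up[:, :intermediate, :].contiguous()
    up_proj = gate_up[:, intermediate:, :].contiguous()

    # The first I rows of dim 1 are the gate half, the last I the up half.
    assert gate_proj.shape == up_proj.shape == (E, I, H)
    assert torch.equal(up_proj, gate_up[:, I:, :])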
@@ -92,10 +112,27 @@ def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchCon
    """
    Clear original expert weights to free memory after KT weights are loaded.
    """
    from .arch import detect_fused_experts

    experts = getattr(moe_module, moe_config.experts_attr, None)
    if experts is None:
        return

    # Fused format: replace gate_up_proj/down_proj tensors with zero-storage placeholders
    if detect_fused_experts(experts):
        for name in ("gate_up_proj", "down_proj"):
            param = getattr(experts, name, None)
            if not isinstance(param, torch.nn.Parameter):
                continue
            original_dtype = param.dtype
            tiny_storage = torch.UntypedStorage(1, device="cpu")
            fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_(
                tiny_storage, storage_offset=0, size=param.shape,
                stride=[0] * len(param.shape),
            )
            experts._parameters[name] = nn.Parameter(fake_tensor, requires_grad=False)
        return

    def _iter_weight_params():
        for expert in experts:
            for weight_name in moe_config.weight_names:
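The zero-storage placeholder trick in this hunk can be demonstrated standalone; the module and sizes below are illustrative only:

    import torch
    import torch.nn as nn

    experts = nn.Module()  # stand-in for a fused-experts module
    experts.gate_up_proj = nn.Parameter(torch.randn(8, 64, 32))

    param = experts.gate_up_proj
    tiny_storage = torch.UntypedStorage(1, device="cpu")
    fake = torch.tensor([], dtype=param.dtype, device="cpu").set_(
        tiny_storage, storage_offset=0, size=param.shape,
        stride=[0] * param.dim(),
    )
    experts._parameters["gate_up_proj"] = nn.Parameter(fake, requires_grad=False)

    # The shape survives for any code that inspects it, but the backing
    # storage is now a single byte instead of 8 * 64 * 32 * 4 bytes.
    assert experts.gate_up_proj.shape == (8, 64, 32)
    print(experts.gate_up_proj.untyped_storage().nbytes())  # 1

Zero strides make every index alias the same byte, so the placeholder is a shape-preserving stub that must never actually be read.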