feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss and expert-LoRA values match the v4 baseline; the v4 (non-fused) path remains backward compatible.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Author: mrhaoxx
Date: 2026-04-20 13:21:29 +08:00
Commit: 58d7eabb9b (parent 6d4632b8c7)
6 changed files with 249 additions and 69 deletions

arch.py

@ -12,6 +12,7 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
@ -136,6 +137,21 @@ def get_moe_module(layer: nn.Module, moe_config: MOEArchConfig) -> nn.Module | None:
return moe_module
def detect_fused_experts(experts: nn.Module) -> bool:
"""Detect if experts module uses the transformers v5 fused format.
Fused format: a single Module with ``gate_up_proj`` [E, 2I, H] and
``down_proj`` [E, H, I] 3-D tensors instead of a ModuleList of Linear experts.
"""
if experts is None:
return False
gate_up = getattr(experts, "gate_up_proj", None)
down = getattr(experts, "down_proj", None)
if isinstance(gate_up, torch.Tensor) and isinstance(down, torch.Tensor):
return gate_up.dim() == 3 and down.dim() == 3
return False
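For illustration, a minimal standalone sketch of the two expert layouts this detector distinguishes (toy sizes; the fused module below is a hand-built stand-in, not the real Qwen3MoeExperts class, and detect_fused_experts is assumed to be imported from arch.py above):

import torch
import torch.nn as nn

E, I, H = 4, 8, 16  # toy expert count, intermediate size, hidden size

# v5-style fused experts: one module holding 3-D packed parameters
fused = nn.Module()
fused.gate_up_proj = nn.Parameter(torch.empty(E, 2 * I, H))
fused.down_proj = nn.Parameter(torch.empty(E, H, I))

# v4-style experts: a ModuleList of per-expert Linear projections
unfused = nn.ModuleList([nn.Linear(H, I) for _ in range(E)])

assert detect_fused_experts(fused) is True     # 3-D gate_up_proj/down_proj present
assert detect_fused_experts(unfused) is False  # no packed tensors on a ModuleList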
def _get_layers_prefix(config) -> str:
arch = config.architectures[0] if getattr(config, "architectures", None) else ""
if any(x in arch for x in ["Deepseek", "Qwen", "Mixtral", "Llama"]):

autograd.py

@ -76,7 +76,7 @@ class KTMoEFunction(torch.autograd.Function):
# Rank 0: sync CPU result and split by real lengths
if rank == 0:
cpu_output = wrapper.sync_forward(output_device=original_device)
cpu_output = wrapper.sync_forward_sft(output_device=original_device)
cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, hidden_size)
offsets = _qlen_offsets(all_qlens_list)
scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
@ -96,7 +96,7 @@ class KTMoEFunction(torch.autograd.Function):
del output_flat
elif wrapper is not None:
# Single-GPU: sync directly
cpu_output = wrapper.sync_forward(output_device=original_device)
cpu_output = wrapper.sync_forward_sft(output_device=original_device)
output = cpu_output.view(batch_size, seq_len, hidden_size).to(dtype=original_dtype)
else:
# Broadcast-only rank (no wrapper)

layer.py

@ -82,10 +82,6 @@ class KTMoELayerWrapper(nn.Module):
# PEFT LoRA tracking (set by kt_adapt_peft_lora)
# _peft_lora_modules: {expert_idx: {proj_name: (lora_A, lora_B)}}
self._peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]] | None = None
self._peft_lora_rank: int = 0
self._peft_lora_alpha: float = 0.0
self._skip_lora: bool = False # True when using SkipLoRA backend (no LoRA on experts)
self._lora_pointers_dirty = False
def _apply(self, fn, recurse=True):
@ -210,7 +206,7 @@ class KTMoELayerWrapper(nn.Module):
if rank == 0:
if self.wrapper is None:
raise RuntimeError("Rank0 wrapper is required in distributed KT overlap path.")
cpu_output = self.wrapper.sync_forward(output_device=original_device)
cpu_output = self.wrapper.sync_forward_sft(output_device=original_device)
cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, self.hidden_size)
offsets = _qlen_offsets(all_qlens_list)
scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
@ -231,7 +227,7 @@ class KTMoELayerWrapper(nn.Module):
return output
if self.wrapper is not None:
cpu_output = self.wrapper.sync_forward(output_device=original_device)
cpu_output = self.wrapper.sync_forward_sft(output_device=original_device)
output = cpu_output.view(batch_size, seq_len, self.hidden_size).to(dtype=original_dtype)
return output
@ -263,7 +259,18 @@ class KTMoELayerWrapper(nn.Module):
topk_weights = topk_weights.to(torch.bfloat16)
return topk_ids, topk_weights
router_logits = router(hidden_states.view(-1, self.hidden_size))
router_output = router(hidden_states.view(-1, self.hidden_size))
# transformers v5 TopKRouter returns (router_logits, router_scores, router_indices)
# directly — scores/indices are already topk-normalized.
if isinstance(router_output, tuple):
if len(router_output) >= 3:
_logits, topk_weights, topk_ids = router_output[0], router_output[1], router_output[2]
if topk_weights.is_floating_point():
topk_weights = topk_weights.to(torch.bfloat16)
return topk_ids, topk_weights
router_output = router_output[0]
router_logits = router_output
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, self.moe_config.num_experts_per_tok, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
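A standalone sketch of the two router contracts handled above (toy sizes; the v5 tuple layout mirrors the comment in the code, and the real TopKRouter may differ in detail):

import torch
import torch.nn.functional as F

tokens, num_experts, top_k = 5, 8, 2
logits = torch.randn(tokens, num_experts)

# v4-style router: raw logits; the wrapper does softmax -> topk -> renormalize
weights = F.softmax(logits, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(weights, top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

# v5-style router: returns (logits, topk-normalized scores, indices) directly;
# the wrapper only unpacks the tuple and casts the scores to bfloat16
router_output = (logits, topk_weights, topk_ids)
_logits, scores, ids = router_output
scores = scores.to(torch.bfloat16)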
@ -328,7 +335,7 @@ class KTMoELayerWrapper(nn.Module):
all_hs = torch.cat(gathered_hs, dim=0)
all_ids = torch.cat(gathered_ids, dim=0)
all_wts = torch.cat(gathered_wts, dim=0)
self.wrapper.submit_forward(
self.wrapper.submit_forward_sft(
all_hs,
all_ids,
all_wts,
@ -357,7 +364,7 @@ class KTMoELayerWrapper(nn.Module):
submit_hs = input_flat.detach()
submit_ids = expert_ids.detach()
submit_wts = weights.detach()
self.wrapper.submit_forward(
self.wrapper.submit_forward_sft(
submit_hs,
submit_ids,
submit_wts,

lora.py

@ -118,6 +118,10 @@ def get_kt_lora_params(model: nn.Module) -> list[nn.Parameter]:
params.append(lora_A.weight)
if hasattr(lora_B, 'weight') and lora_B.weight.requires_grad:
params.append(lora_B.weight)
# Fused expert LoRA parameters (KT-managed, not PEFT)
fused_params = getattr(wrapper, "_fused_expert_lora_params", None)
if fused_params is not None:
params.extend(fused_params)
# lora_experts parameters (separate feature)
if getattr(wrapper, "lora_experts", None) is not None:
params.extend(wrapper.lora_experts.parameters())
@ -163,7 +167,34 @@ def kt_adapt_peft_lora(model: nn.Module) -> None:
experts_attr = getattr(wrapper, "_experts_attr", "experts")
experts = getattr(wrapper, experts_attr, None)
if experts is None or len(experts) == 0:
if experts is None:
continue
# Fused experts (transformers v5): PEFT cannot auto-attach LoRA to packed
# nn.Parameter tensors. Create KT-managed LoRA buffers with proper init,
# wrap as nn.Parameter for optimizer, and pre-assign .grad for C++ backward.
if getattr(wrapper, "_fused_experts", False):
lora_rank = getattr(wrapper, "_lora_rank", 1)
lora_buffers, lora_grad_buffers, lora_params = _create_fused_expert_lora_buffers(
wrapper, moe_config, lora_rank, torch.bfloat16,
)
if is_rank_0 and wrapper.wrapper is not None:
all_buffers = {}
all_buffers.update(lora_buffers)
all_buffers.update(lora_grad_buffers)
wrapper.wrapper.init_lora_weights(**all_buffers)
logger.info(
f"[kt_adapt_peft_lora] Layer {layer_idx}: fused expert LoRA "
f"(r={lora_rank}, E={moe_config.expert_num})"
)
wrapper._fused_expert_lora_params = lora_params
wrapper._peft_lora_modules = None
adapted_count += 1
continue
if len(experts) == 0:
continue
# Collect references to PEFT LoRA modules for each expert
@ -197,21 +228,11 @@ def kt_adapt_peft_lora(model: nn.Module) -> None:
# Store PEFT LoRA references on wrapper
wrapper._peft_lora_modules = peft_lora_modules
# SkipLoRA mode: if no LoRA found on experts, skip buffer creation
if not peft_lora_modules:
if getattr(wrapper, '_skip_lora', False):
logger.info(
f"[kt_adapt_peft_lora] Layer {layer_idx}: SkipLoRA mode, "
f"no PEFT LoRA on experts — skipping LoRA buffer creation"
)
adapted_count += 1
continue
else:
raise RuntimeError(
f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found on any expert. "
f"If you intend to train without expert LoRA, use a SkipLoRA backend "
f"(e.g., kt_backend: AMXINT8_SkipLoRA)."
)
raise RuntimeError(
f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found on any expert. "
f"Check that PEFT lora_target includes expert modules."
)
# Allocate contiguous bf16 buffers and populate with initial PEFT values (all ranks)
lora_buffers = _create_lora_view_buffers(peft_lora_modules, moe_config, torch.bfloat16)
@ -243,6 +264,8 @@ def kt_adapt_peft_lora(model: nn.Module) -> None:
experts = getattr(wrapper, experts_attr, None)
if experts is None:
continue
if getattr(wrapper, "_fused_experts", False):
continue
for expert in experts:
for param_name, param in list(expert.named_parameters()):
if param.requires_grad:
@ -372,6 +395,62 @@ def _create_lora_grad_buffers(
return buffers
def _create_fused_expert_lora_buffers(
wrapper,
moe_config: MOEArchConfig,
lora_rank: int,
dtype: torch.dtype = torch.bfloat16,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], list[nn.Parameter]]:
"""
Create KT-managed LoRA buffers for fused expert modules.
Fused experts store weights as 3D parameters (gate_up_proj [E, 2I, H], down_proj [E, H, I])
rather than per-expert nn.Linear modules. PEFT can't attach per-expert LoRA to these,
so we create our own LoRA buffers that the C++ kernel reads/writes directly.
Returns:
(lora_buffers, lora_grad_buffers, lora_params):
- lora_buffers: dict of weight buffers for C++ init_lora_weights()
- lora_grad_buffers: dict of grad buffers for C++ backward
- lora_params: list of nn.Parameter wrappers for the optimizer
"""
E = moe_config.expert_num
I = moe_config.intermediate_size
H = wrapper.hidden_size
r = lora_rank
logger.info(f"[_create_fused_expert_lora_buffers] E={E}, I={I}, H={H}, r={r}")
lora_buffers = {
"gate_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
"gate_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
"up_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
"up_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
"down_lora_a": torch.zeros(E, r, I, dtype=dtype, device="cpu"),
"down_lora_b": torch.zeros(E, H, r, dtype=dtype, device="cpu"),
}
for key in ("gate_lora_a", "up_lora_a", "down_lora_a"):
nn.init.kaiming_uniform_(lora_buffers[key].view(E * r, -1), a=math.sqrt(5))
lora_grad_buffers = {
"grad_gate_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
"grad_gate_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
"grad_up_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
"grad_up_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
"grad_down_lora_a": torch.zeros(E, r, I, dtype=dtype, device="cpu"),
"grad_down_lora_b": torch.zeros(E, H, r, dtype=dtype, device="cpu"),
}
lora_params = []
for key in ("gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"):
param = nn.Parameter(lora_buffers[key], requires_grad=True)
param.grad = lora_grad_buffers[f"grad_{key}"]
lora_params.append(param)
return lora_buffers, lora_grad_buffers, lora_params
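A standalone sketch of why .grad is pre-assigned above (illustrative names and shapes): the C++ backward writes gradients into the pre-assigned buffer in place, so a standard optimizer step works without autograd ever allocating a gradient tensor:

import torch

param = torch.nn.Parameter(torch.zeros(4, 2, 16))  # stand-in for one fused LoRA buffer
grad_buf = torch.zeros_like(param)                  # buffer the C++ backward would fill
param.grad = grad_buf                               # pre-assign, as in the code above

opt = torch.optim.SGD([param], lr=0.1)
grad_buf.fill_(1.0)   # stand-in for the kernel writing gradients in place
opt.step()            # optimizer reads the filled buffer via param.grad
assert torch.allclose(param, torch.full_like(param, -0.1))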
# =============================================================================
# PEFT Weight View Replacement
# =============================================================================
@ -510,15 +589,7 @@ def save_lora_experts_to_adapter(model: nn.Module, output_dir: str) -> None:
from safetensors import safe_open
from safetensors.torch import save_file
wrappers = getattr(model, "_kt_wrappers", [])
if not wrappers:
base_model = model
for attr in ["base_model", "model"]:
if hasattr(base_model, attr):
base_model = getattr(base_model, attr)
wrappers = getattr(base_model, "_kt_wrappers", [])
if wrappers:
break
wrappers = _find_kt_wrappers(model) or []
if not wrappers:
logger.warning("No KT wrappers found, skipping LoRA Experts saving")
return
@ -568,25 +639,80 @@ def save_kt_moe_to_adapter(model: nn.Module, output_dir: str) -> None:
Note: Per-expert PEFT LoRA is saved by PEFT directly, not here.
This function only handles lora_experts (a separate feature).
"""
wrappers = getattr(model, "_kt_wrappers", [])
if not wrappers:
base_model = model
for attr in ["base_model", "model"]:
if hasattr(base_model, attr):
base_model = getattr(base_model, attr)
wrappers = getattr(base_model, "_kt_wrappers", [])
if wrappers:
break
wrappers = _find_kt_wrappers(model) or []
if not wrappers:
logger.info("[save_kt_moe] No KT wrappers found, skipping")
return
has_lora_experts = any(w.lora_experts is not None for w in wrappers)
has_fused_lora = any(getattr(w, "_fused_expert_lora_params", None) is not None for w in wrappers)
if has_lora_experts:
save_lora_experts_to_adapter(model, output_dir)
else:
logger.info("[save_kt_moe] No lora_experts in KT wrappers")
if has_fused_lora:
_save_fused_expert_lora(wrappers, output_dir)
if not has_lora_experts and not has_fused_lora:
logger.info("[save_kt_moe] No lora_experts or fused expert LoRA in KT wrappers")
def _save_fused_expert_lora(wrappers: list, output_dir: str) -> None:
"""Save fused expert LoRA params to a safetensors file."""
from safetensors.torch import save_file
names = ["gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"]
tensors = {}
for w in wrappers:
fused = getattr(w, "_fused_expert_lora_params", None)
if fused is None:
continue
for param, name in zip(fused, names):
key = f"layers.{w.layer_idx}.experts.{name}"
tensors[key] = param.data.clone()
if tensors:
path = os.path.join(output_dir, "fused_expert_lora.safetensors")
save_file(tensors, path)
logger.info(f"[save_kt_moe] Saved {len(tensors)} fused expert LoRA tensors to {path}")
def _load_fused_expert_lora(wrappers: list, adapter_path: str) -> None:
"""Load fused expert LoRA params from a safetensors file into existing wrapper buffers."""
path = os.path.join(adapter_path, "fused_expert_lora.safetensors")
if not os.path.isfile(path):
logger.warning(f"No fused_expert_lora.safetensors found at {adapter_path}")
return
from safetensors.torch import load_file
saved = load_file(path)
names = ["gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"]
wrapper_map = {w.layer_idx: w for w in wrappers}
loaded_count = 0
for key, tensor in saved.items():
parts = key.split(".")
if len(parts) != 4 or parts[0] != "layers" or parts[2] != "experts":
logger.warning(f"Unexpected key in fused_expert_lora.safetensors: {key}")
continue
layer_idx = int(parts[1])
name = parts[3]
if name not in names:
continue
wrapper = wrapper_map.get(layer_idx)
if wrapper is None:
continue
fused = getattr(wrapper, "_fused_expert_lora_params", None)
if fused is None:
continue
param_idx = names.index(name)
fused[param_idx].data.copy_(tensor)
loaded_count += 1
logger.info(f"[_load_fused_expert_lora] Loaded {loaded_count} tensors from {path}")
def load_lora_experts_from_adapter(model: nn.Module, adapter_path: str) -> None:
@ -595,15 +721,7 @@ def load_lora_experts_from_adapter(model: nn.Module, adapter_path: str) -> None:
"""
from safetensors import safe_open
wrappers = getattr(model, "_kt_wrappers", [])
if not wrappers:
base_model = model
for attr in ["base_model", "model"]:
if hasattr(base_model, attr):
base_model = getattr(base_model, attr)
wrappers = getattr(base_model, "_kt_wrappers", [])
if wrappers:
break
wrappers = _find_kt_wrappers(model) or []
if not wrappers:
logger.warning("No KT wrappers found, skipping LoRA Experts loading")
return
@ -667,22 +785,19 @@ def load_kt_moe_from_adapter(model: nn.Module, adapter_path: str) -> None:
Note: Per-expert PEFT LoRA is loaded by PEFT directly, not here.
This function only handles lora_experts (a separate feature).
"""
wrappers = getattr(model, "_kt_wrappers", [])
if not wrappers:
base_model = model
for attr in ["base_model", "model"]:
if hasattr(base_model, attr):
base_model = getattr(base_model, attr)
wrappers = getattr(base_model, "_kt_wrappers", [])
if wrappers:
break
wrappers = _find_kt_wrappers(model) or []
if not wrappers:
logger.warning("No KT wrappers found, skipping KT MoE loading")
return
has_lora_experts = any(w.lora_experts is not None for w in wrappers)
has_fused_lora = any(getattr(w, "_fused_expert_lora_params", None) is not None for w in wrappers)
if has_lora_experts:
load_lora_experts_from_adapter(model, adapter_path)
else:
logger.info("No lora_experts in KT wrappers (PEFT LoRA is loaded by PEFT directly)")
if has_fused_lora:
_load_fused_expert_lora(wrappers, adapter_path)
if not has_lora_experts and not has_fused_lora:
logger.info("No lora_experts or fused expert LoRA in KT wrappers")

weights.py

@ -40,8 +40,28 @@ def extract_moe_weights(
Returns (gate_proj, up_proj, down_proj) with shape
[expert_num, out_features, in_features].
Supports two formats:
- ModuleList of Linear experts (transformers v4 style)
- Fused Parameters (transformers v5 style): single module with
``gate_up_proj`` [E, 2*I, H] and ``down_proj`` [E, H, I] tensors.
"""
from .arch import detect_fused_experts
experts = getattr(moe_module, moe_config.experts_attr)
# Fused format (transformers v5): a single nn.Module with gate_up_proj/down_proj tensors
if detect_fused_experts(experts):
gate_up = getattr(experts, "gate_up_proj").data
down_fused = getattr(experts, "down_proj").data
# gate_up_proj is [E, 2*I, H], split into gate [E, I, H] and up [E, I, H]
intermediate = gate_up.shape[1] // 2
gate_proj = gate_up[:, :intermediate, :].contiguous()
up_proj = gate_up[:, intermediate:, :].contiguous()
# down_proj is already [E, H, I]
down_proj = down_fused.contiguous()
return gate_proj, up_proj, down_proj
gate_name, up_name, down_name = moe_config.weight_names
gather_params: list[torch.nn.Parameter] = []
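A quick standalone shape check of the fused split above (toy sizes; it relies on the gate half occupying the first I rows of gate_up_proj, as the code assumes), showing each half matches the per-expert nn.Linear(H, I).weight layout of [out_features, in_features] = [I, H]:

import torch

E, I, H = 4, 8, 16
gate_up = torch.randn(E, 2 * I, H)      # fused gate_up_proj
gate = gate_up[:, :I, :].contiguous()   # [E, I, H]
up = gate_up[:, I:, :].contiguous()     # [E, I, H]
assert gate.shape == up.shape == (E, I, H)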
@ -92,10 +112,27 @@ def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchCon
"""
Clear original expert weights to free memory after KT weights are loaded.
"""
from .arch import detect_fused_experts
experts = getattr(moe_module, moe_config.experts_attr, None)
if experts is None:
return
# Fused format: replace gate_up_proj/down_proj tensors with zero-storage placeholders
if detect_fused_experts(experts):
for name in ("gate_up_proj", "down_proj"):
param = getattr(experts, name, None)
if not isinstance(param, torch.nn.Parameter):
continue
original_dtype = param.dtype
tiny_storage = torch.UntypedStorage(1, device="cpu")
fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_(
tiny_storage, storage_offset=0, size=param.shape,
stride=[0] * len(param.shape),
)
experts._parameters[name] = nn.Parameter(fake_tensor, requires_grad=False)
return
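The placeholder above keeps the original shape and dtype while holding essentially no memory. A standalone sketch of the same stride-0 idea, shown here with expand over a one-element tensor rather than the raw storage/set_ call used in the diff:

import torch

shape = (4, 16, 32)
placeholder = torch.empty(1, dtype=torch.bfloat16).expand(shape)  # stride-0 view over 1 element
assert placeholder.shape == shape
assert placeholder.stride() == (0, 0, 0)
assert placeholder.untyped_storage().nbytes() == 2  # one bf16 element backs the whole "tensor"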
def _iter_weight_params():
for expert in experts:
for weight_name in moe_config.weight_names:

wrapper.py

@ -264,12 +264,17 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
model_container, layers = _get_model_container_and_layers(model, purpose="wrapping")
logger.info(f"Total layers={len(layers)}, is_rank_0={is_rank_0}")
from .arch import detect_fused_experts as _detect_fused
for layer_idx, layer in enumerate(layers):
moe_module = get_moe_module(layer, moe_config)
if moe_module is None:
continue
logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method})")
_layer_experts = getattr(moe_module, moe_config.experts_attr, None)
_layer_is_fused = _detect_fused(_layer_experts)
logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method}, fused={_layer_is_fused})")
# Only rank 0 loads weights and initializes KT kernel
gate_proj, up_proj, down_proj = None, None, None
@ -312,7 +317,6 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
num_experts_per_tok=moe_config.num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_config.intermediate_size,
gpu_experts_mask=None,
num_gpu_experts=0,
cpuinfer_threads=getattr(cfg, "kt_num_threads", 1),
threadpool_count=threadpool_count,
@ -370,7 +374,8 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
layer_idx=layer_idx,
lora_experts=lora_experts,
)
layer_wrapper._skip_lora = "SkipLoRA" in kt_method
layer_wrapper._fused_experts = _layer_is_fused
layer_wrapper._lora_rank = lora_rank
setattr(layer, moe_config.moe_layer_attr, layer_wrapper)
# Base weights have been copied into the C++ kernel's internal BufferB format.