mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 20:00:06 +00:00
feat(sft): support transformers v5 fused expert format
Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters (gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert nn.Linear modules. PEFT cannot attach LoRA to these, so we create KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers for the optimizer, and pre-assigned .grad for C++ backward. - arch.py: detect_fused_experts() detection - weights.py: fused format extraction and weight clearing - wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank - lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA, get_kt_lora_params collects fused params, deduplicate wrapper finding - layer.py: handle v5 TopKRouter tuple output, remove dead code - autograd.py: sync_forward_sft/submit_forward_sft API rename Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
6d4632b8c7
commit
58d7eabb9b
6 changed files with 249 additions and 69 deletions
|
|
@@ -264,12 +264,17 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
|
|||
model_container, layers = _get_model_container_and_layers(model, purpose="wrapping")
|
||||
logger.info(f"Total layers={len(layers)}, is_rank_0={is_rank_0}")
|
||||
|
||||
from .arch import detect_fused_experts as _detect_fused
|
||||
|
||||
for layer_idx, layer in enumerate(layers):
|
||||
moe_module = get_moe_module(layer, moe_config)
|
||||
if moe_module is None:
|
||||
continue
|
||||
|
||||
logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method})")
|
||||
_layer_experts = getattr(moe_module, moe_config.experts_attr, None)
|
||||
_layer_is_fused = _detect_fused(_layer_experts)
|
||||
|
||||
logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method}, fused={_layer_is_fused})")
|
||||
|
||||
# Only rank 0 loads weights and initializes KT kernel
|
||||
gate_proj, up_proj, down_proj = None, None, None
|
||||
|
|
@@ -312,7 +317,6 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
|
|||
num_experts_per_tok=moe_config.num_experts_per_tok,
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_size=moe_config.intermediate_size,
|
||||
gpu_experts_mask=None,
|
||||
num_gpu_experts=0,
|
||||
cpuinfer_threads=getattr(cfg, "kt_num_threads", 1),
|
||||
threadpool_count=threadpool_count,
|
||||
|
|
@@ -370,7 +374,8 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
|
|||
layer_idx=layer_idx,
|
||||
lora_experts=lora_experts,
|
||||
)
|
||||
layer_wrapper._skip_lora = "SkipLoRA" in kt_method
|
||||
layer_wrapper._fused_experts = _layer_is_fused
|
||||
layer_wrapper._lora_rank = lora_rank
|
||||
|
||||
setattr(layer, moe_config.moe_layer_attr, layer_wrapper)
|
||||
# Base weights have been copied into the C++ kernel's internal BufferB format.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue