mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 03:39:48 +00:00
feat(sft): support transformers v5 fused expert format
Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters (gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert nn.Linear modules. PEFT cannot attach LoRA to these, so we create KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers for the optimizer, and pre-assigned .grad for C++ backward. - arch.py: detect_fused_experts() detection - weights.py: fused format extraction and weight clearing - wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank - lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA, get_kt_lora_params collects fused params, deduplicate wrapper finding - layer.py: handle v5 TopKRouter tuple output, remove dead code - autograd.py: sync_forward_sft/submit_forward_sft API rename Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
6d4632b8c7
commit
58d7eabb9b
6 changed files with 249 additions and 69 deletions
|
|
@ -12,6 +12,7 @@ from __future__ import annotations
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -136,6 +137,21 @@ def get_moe_module(layer: nn.Module, moe_config: MOEArchConfig) -> nn.Module | N
|
|||
return moe_module
|
||||
|
||||
|
||||
def detect_fused_experts(experts: nn.Module) -> bool:
    """Return True if *experts* uses the transformers v5 fused weight format.

    In the fused format the experts are a single module exposing
    ``gate_up_proj`` [E, 2I, H] and ``down_proj`` [E, H, I] as 3-D tensors,
    rather than a ModuleList of per-expert Linear modules.
    """
    if experts is None:
        return False
    projections = (
        getattr(experts, "gate_up_proj", None),
        getattr(experts, "down_proj", None),
    )
    # Both packed weights must exist, be tensors, and be rank-3.
    return all(isinstance(p, torch.Tensor) and p.dim() == 3 for p in projections)
|
||||
|
||||
|
||||
def _get_layers_prefix(config) -> str:
|
||||
arch = config.architectures[0] if getattr(config, "architectures", None) else ""
|
||||
if any(x in arch for x in ["Deepseek", "Qwen", "Mixtral", "Llama"]):
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class KTMoEFunction(torch.autograd.Function):
|
|||
|
||||
# Rank 0: sync CPU result and split by real lengths
|
||||
if rank == 0:
|
||||
cpu_output = wrapper.sync_forward(output_device=original_device)
|
||||
cpu_output = wrapper.sync_forward_sft(output_device=original_device)
|
||||
cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, hidden_size)
|
||||
offsets = _qlen_offsets(all_qlens_list)
|
||||
scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
|
||||
|
|
@ -96,7 +96,7 @@ class KTMoEFunction(torch.autograd.Function):
|
|||
del output_flat
|
||||
elif wrapper is not None:
|
||||
# Single-GPU: sync directly
|
||||
cpu_output = wrapper.sync_forward(output_device=original_device)
|
||||
cpu_output = wrapper.sync_forward_sft(output_device=original_device)
|
||||
output = cpu_output.view(batch_size, seq_len, hidden_size).to(dtype=original_dtype)
|
||||
else:
|
||||
# Broadcast-only rank (no wrapper)
|
||||
|
|
|
|||
|
|
@ -82,10 +82,6 @@ class KTMoELayerWrapper(nn.Module):
|
|||
# PEFT LoRA tracking (set by kt_adapt_peft_lora)
|
||||
# _peft_lora_modules: {expert_idx: {proj_name: (lora_A, lora_B)}}
|
||||
self._peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]] | None = None
|
||||
self._peft_lora_rank: int = 0
|
||||
self._peft_lora_alpha: float = 0.0
|
||||
self._skip_lora: bool = False # True when using SkipLoRA backend (no LoRA on experts)
|
||||
|
||||
self._lora_pointers_dirty = False
|
||||
|
||||
def _apply(self, fn, recurse=True):
|
||||
|
|
@ -210,7 +206,7 @@ class KTMoELayerWrapper(nn.Module):
|
|||
if rank == 0:
|
||||
if self.wrapper is None:
|
||||
raise RuntimeError("Rank0 wrapper is required in distributed KT overlap path.")
|
||||
cpu_output = self.wrapper.sync_forward(output_device=original_device)
|
||||
cpu_output = self.wrapper.sync_forward_sft(output_device=original_device)
|
||||
cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, self.hidden_size)
|
||||
offsets = _qlen_offsets(all_qlens_list)
|
||||
scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
|
||||
|
|
@ -231,7 +227,7 @@ class KTMoELayerWrapper(nn.Module):
|
|||
return output
|
||||
|
||||
if self.wrapper is not None:
|
||||
cpu_output = self.wrapper.sync_forward(output_device=original_device)
|
||||
cpu_output = self.wrapper.sync_forward_sft(output_device=original_device)
|
||||
output = cpu_output.view(batch_size, seq_len, self.hidden_size).to(dtype=original_dtype)
|
||||
return output
|
||||
|
||||
|
|
@ -263,7 +259,18 @@ class KTMoELayerWrapper(nn.Module):
|
|||
topk_weights = topk_weights.to(torch.bfloat16)
|
||||
return topk_ids, topk_weights
|
||||
|
||||
router_logits = router(hidden_states.view(-1, self.hidden_size))
|
||||
router_output = router(hidden_states.view(-1, self.hidden_size))
|
||||
# transformers v5 TopKRouter returns (router_logits, router_scores, router_indices)
|
||||
# directly — scores/indices are already topk-normalized.
|
||||
if isinstance(router_output, tuple):
|
||||
if len(router_output) >= 3:
|
||||
_logits, topk_weights, topk_ids = router_output[0], router_output[1], router_output[2]
|
||||
if topk_weights.is_floating_point():
|
||||
topk_weights = topk_weights.to(torch.bfloat16)
|
||||
return topk_ids, topk_weights
|
||||
router_output = router_output[0]
|
||||
|
||||
router_logits = router_output
|
||||
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(routing_weights, self.moe_config.num_experts_per_tok, dim=-1)
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
|
@ -328,7 +335,7 @@ class KTMoELayerWrapper(nn.Module):
|
|||
all_hs = torch.cat(gathered_hs, dim=0)
|
||||
all_ids = torch.cat(gathered_ids, dim=0)
|
||||
all_wts = torch.cat(gathered_wts, dim=0)
|
||||
self.wrapper.submit_forward(
|
||||
self.wrapper.submit_forward_sft(
|
||||
all_hs,
|
||||
all_ids,
|
||||
all_wts,
|
||||
|
|
@ -357,7 +364,7 @@ class KTMoELayerWrapper(nn.Module):
|
|||
submit_hs = input_flat.detach()
|
||||
submit_ids = expert_ids.detach()
|
||||
submit_wts = weights.detach()
|
||||
self.wrapper.submit_forward(
|
||||
self.wrapper.submit_forward_sft(
|
||||
submit_hs,
|
||||
submit_ids,
|
||||
submit_wts,
|
||||
|
|
|
|||
|
|
@ -118,6 +118,10 @@ def get_kt_lora_params(model: nn.Module) -> list[nn.Parameter]:
|
|||
params.append(lora_A.weight)
|
||||
if hasattr(lora_B, 'weight') and lora_B.weight.requires_grad:
|
||||
params.append(lora_B.weight)
|
||||
# Fused expert LoRA parameters (KT-managed, not PEFT)
|
||||
fused_params = getattr(wrapper, "_fused_expert_lora_params", None)
|
||||
if fused_params is not None:
|
||||
params.extend(fused_params)
|
||||
# lora_experts parameters (separate feature)
|
||||
if getattr(wrapper, "lora_experts", None) is not None:
|
||||
params.extend(wrapper.lora_experts.parameters())
|
||||
|
|
@ -163,7 +167,34 @@ def kt_adapt_peft_lora(model: nn.Module) -> None:
|
|||
experts_attr = getattr(wrapper, "_experts_attr", "experts")
|
||||
experts = getattr(wrapper, experts_attr, None)
|
||||
|
||||
if experts is None or len(experts) == 0:
|
||||
if experts is None:
|
||||
continue
|
||||
|
||||
# Fused experts (transformers v5): PEFT cannot auto-attach LoRA to packed
|
||||
# nn.Parameter tensors. Create KT-managed LoRA buffers with proper init,
|
||||
# wrap as nn.Parameter for optimizer, and pre-assign .grad for C++ backward.
|
||||
if getattr(wrapper, "_fused_experts", False):
|
||||
lora_rank = getattr(wrapper, "_lora_rank", 1)
|
||||
lora_buffers, lora_grad_buffers, lora_params = _create_fused_expert_lora_buffers(
|
||||
wrapper, moe_config, lora_rank, torch.bfloat16,
|
||||
)
|
||||
|
||||
if is_rank_0 and wrapper.wrapper is not None:
|
||||
all_buffers = {}
|
||||
all_buffers.update(lora_buffers)
|
||||
all_buffers.update(lora_grad_buffers)
|
||||
wrapper.wrapper.init_lora_weights(**all_buffers)
|
||||
logger.info(
|
||||
f"[kt_adapt_peft_lora] Layer {layer_idx}: fused expert LoRA "
|
||||
f"(r={lora_rank}, E={moe_config.expert_num})"
|
||||
)
|
||||
|
||||
wrapper._fused_expert_lora_params = lora_params
|
||||
wrapper._peft_lora_modules = None
|
||||
adapted_count += 1
|
||||
continue
|
||||
|
||||
if len(experts) == 0:
|
||||
continue
|
||||
|
||||
# Collect references to PEFT LoRA modules for each expert
|
||||
|
|
@ -197,21 +228,11 @@ def kt_adapt_peft_lora(model: nn.Module) -> None:
|
|||
# Store PEFT LoRA references on wrapper
|
||||
wrapper._peft_lora_modules = peft_lora_modules
|
||||
|
||||
# SkipLoRA mode: if no LoRA found on experts, skip buffer creation
|
||||
if not peft_lora_modules:
|
||||
if getattr(wrapper, '_skip_lora', False):
|
||||
logger.info(
|
||||
f"[kt_adapt_peft_lora] Layer {layer_idx}: SkipLoRA mode, "
|
||||
f"no PEFT LoRA on experts — skipping LoRA buffer creation"
|
||||
)
|
||||
adapted_count += 1
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found on any expert. "
|
||||
f"If you intend to train without expert LoRA, use a SkipLoRA backend "
|
||||
f"(e.g., kt_backend: AMXINT8_SkipLoRA)."
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found on any expert. "
|
||||
f"Check that PEFT lora_target includes expert modules."
|
||||
)
|
||||
|
||||
# Allocate contiguous bf16 buffers and populate with initial PEFT values (all ranks)
|
||||
lora_buffers = _create_lora_view_buffers(peft_lora_modules, moe_config, torch.bfloat16)
|
||||
|
|
@ -243,6 +264,8 @@ def kt_adapt_peft_lora(model: nn.Module) -> None:
|
|||
experts = getattr(wrapper, experts_attr, None)
|
||||
if experts is None:
|
||||
continue
|
||||
if getattr(wrapper, "_fused_experts", False):
|
||||
continue
|
||||
for expert in experts:
|
||||
for param_name, param in list(expert.named_parameters()):
|
||||
if param.requires_grad:
|
||||
|
|
@ -372,6 +395,62 @@ def _create_lora_grad_buffers(
|
|||
return buffers
|
||||
|
||||
|
||||
def _create_fused_expert_lora_buffers(
    wrapper,
    moe_config: MOEArchConfig,
    lora_rank: int,
    dtype: torch.dtype = torch.bfloat16,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], list[nn.Parameter]]:
    """
    Create KT-managed LoRA weight and gradient buffers for fused experts.

    Fused experts pack weights as 3D tensors (gate_up_proj [E, 2I, H],
    down_proj [E, H, I]) instead of per-expert nn.Linear modules, so PEFT
    cannot attach per-expert LoRA. We therefore allocate our own contiguous
    CPU buffers that the C++ kernel reads and writes directly.

    Returns:
        (lora_buffers, lora_grad_buffers, lora_params):
            - lora_buffers: dict of weight buffers for C++ init_lora_weights()
            - lora_grad_buffers: dict of grad buffers for C++ backward
            - lora_params: list of nn.Parameter wrappers for the optimizer
    """
    E = moe_config.expert_num
    I = moe_config.intermediate_size
    H = wrapper.hidden_size
    r = lora_rank

    logger.info(f"[_create_fused_expert_lora_buffers] E={E}, I={I}, H={H}, r={r}")

    # Per-projection LoRA shapes: A maps input -> rank, B maps rank -> output.
    shapes = {
        "gate_lora_a": (E, r, H),
        "gate_lora_b": (E, I, r),
        "up_lora_a": (E, r, H),
        "up_lora_b": (E, I, r),
        "down_lora_a": (E, r, I),
        "down_lora_b": (E, H, r),
    }

    lora_buffers = {
        name: torch.zeros(*shape, dtype=dtype, device="cpu")
        for name, shape in shapes.items()
    }
    lora_grad_buffers = {
        f"grad_{name}": torch.zeros(*shape, dtype=dtype, device="cpu")
        for name, shape in shapes.items()
    }

    # Standard LoRA init: kaiming-uniform on A, zero B, so the adapter
    # contributes nothing at step 0. Iteration order matches the original
    # (gate_lora_a, up_lora_a, down_lora_a) so RNG consumption is unchanged.
    for name, buf in lora_buffers.items():
        if name.endswith("_a"):
            nn.init.kaiming_uniform_(buf.view(E * r, -1), a=math.sqrt(5))

    lora_params = []
    for name in shapes:
        param = nn.Parameter(lora_buffers[name], requires_grad=True)
        # Pre-assign .grad so the C++ backward can accumulate in place.
        param.grad = lora_grad_buffers[f"grad_{name}"]
        lora_params.append(param)

    return lora_buffers, lora_grad_buffers, lora_params
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PEFT Weight View Replacement
|
||||
# =============================================================================
|
||||
|
|
@ -510,15 +589,7 @@ def save_lora_experts_to_adapter(model: nn.Module, output_dir: str) -> None:
|
|||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
wrappers = getattr(model, "_kt_wrappers", [])
|
||||
if not wrappers:
|
||||
base_model = model
|
||||
for attr in ["base_model", "model"]:
|
||||
if hasattr(base_model, attr):
|
||||
base_model = getattr(base_model, attr)
|
||||
wrappers = getattr(base_model, "_kt_wrappers", [])
|
||||
if wrappers:
|
||||
break
|
||||
wrappers = _find_kt_wrappers(model) or []
|
||||
if not wrappers:
|
||||
logger.warning("No KT wrappers found, skipping LoRA Experts saving")
|
||||
return
|
||||
|
|
@ -568,25 +639,80 @@ def save_kt_moe_to_adapter(model: nn.Module, output_dir: str) -> None:
|
|||
Note: Per-expert PEFT LoRA is saved by PEFT directly, not here.
|
||||
This function only handles lora_experts (a separate feature).
|
||||
"""
|
||||
wrappers = getattr(model, "_kt_wrappers", [])
|
||||
if not wrappers:
|
||||
base_model = model
|
||||
for attr in ["base_model", "model"]:
|
||||
if hasattr(base_model, attr):
|
||||
base_model = getattr(base_model, attr)
|
||||
wrappers = getattr(base_model, "_kt_wrappers", [])
|
||||
if wrappers:
|
||||
break
|
||||
wrappers = _find_kt_wrappers(model) or []
|
||||
if not wrappers:
|
||||
logger.info("[save_kt_moe] No KT wrappers found, skipping")
|
||||
return
|
||||
|
||||
has_lora_experts = any(w.lora_experts is not None for w in wrappers)
|
||||
has_fused_lora = any(getattr(w, "_fused_expert_lora_params", None) is not None for w in wrappers)
|
||||
|
||||
if has_lora_experts:
|
||||
save_lora_experts_to_adapter(model, output_dir)
|
||||
else:
|
||||
logger.info("[save_kt_moe] No lora_experts in KT wrappers")
|
||||
|
||||
if has_fused_lora:
|
||||
_save_fused_expert_lora(wrappers, output_dir)
|
||||
|
||||
if not has_lora_experts and not has_fused_lora:
|
||||
logger.info("[save_kt_moe] No lora_experts or fused expert LoRA in KT wrappers")
|
||||
|
||||
|
||||
def _save_fused_expert_lora(wrappers: list, output_dir: str) -> None:
    """Save fused expert LoRA params to a safetensors file."""
    from safetensors.torch import save_file

    names = ["gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"]
    tensors = {}
    for wrapper in wrappers:
        lora_params = getattr(wrapper, "_fused_expert_lora_params", None)
        if lora_params is None:
            continue
        # Params are stored in the same fixed order as `names`.
        for name, param in zip(names, lora_params):
            tensors[f"layers.{wrapper.layer_idx}.experts.{name}"] = param.data.clone()

    # Nothing to write if no wrapper carries fused LoRA params.
    if not tensors:
        return

    path = os.path.join(output_dir, "fused_expert_lora.safetensors")
    save_file(tensors, path)
    logger.info(f"[save_kt_moe] Saved {len(tensors)} fused expert LoRA tensors to {path}")
|
||||
|
||||
|
||||
def _load_fused_expert_lora(wrappers: list, adapter_path: str) -> None:
    """Load fused expert LoRA params from a safetensors file into existing wrapper buffers."""
    path = os.path.join(adapter_path, "fused_expert_lora.safetensors")
    if not os.path.isfile(path):
        logger.warning(f"No fused_expert_lora.safetensors found at {adapter_path}")
        return

    from safetensors.torch import load_file

    saved = load_file(path)
    names = ["gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"]
    by_layer = {w.layer_idx: w for w in wrappers}
    # Map each tensor name to its position in the wrapper's param list.
    slot_of = {name: idx for idx, name in enumerate(names)}
    loaded_count = 0

    for key, tensor in saved.items():
        parts = key.split(".")
        # Expected key format: "layers.<idx>.experts.<name>".
        if len(parts) != 4 or parts[0] != "layers" or parts[2] != "experts":
            logger.warning(f"Unexpected key in fused_expert_lora.safetensors: {key}")
            continue
        layer_idx = int(parts[1])
        slot = slot_of.get(parts[3])
        if slot is None:
            continue

        target = by_layer.get(layer_idx)
        if target is None:
            continue
        lora_params = getattr(target, "_fused_expert_lora_params", None)
        if lora_params is None:
            continue

        # Copy into the existing buffer so the C++ kernel's views stay valid.
        lora_params[slot].data.copy_(tensor)
        loaded_count += 1

    logger.info(f"[_load_fused_expert_lora] Loaded {loaded_count} tensors from {path}")
|
||||
|
||||
|
||||
def load_lora_experts_from_adapter(model: nn.Module, adapter_path: str) -> None:
|
||||
|
|
@ -595,15 +721,7 @@ def load_lora_experts_from_adapter(model: nn.Module, adapter_path: str) -> None:
|
|||
"""
|
||||
from safetensors import safe_open
|
||||
|
||||
wrappers = getattr(model, "_kt_wrappers", [])
|
||||
if not wrappers:
|
||||
base_model = model
|
||||
for attr in ["base_model", "model"]:
|
||||
if hasattr(base_model, attr):
|
||||
base_model = getattr(base_model, attr)
|
||||
wrappers = getattr(base_model, "_kt_wrappers", [])
|
||||
if wrappers:
|
||||
break
|
||||
wrappers = _find_kt_wrappers(model) or []
|
||||
if not wrappers:
|
||||
logger.warning("No KT wrappers found, skipping LoRA Experts loading")
|
||||
return
|
||||
|
|
@ -667,22 +785,19 @@ def load_kt_moe_from_adapter(model: nn.Module, adapter_path: str) -> None:
|
|||
Note: Per-expert PEFT LoRA is loaded by PEFT directly, not here.
|
||||
This function only handles lora_experts (a separate feature).
|
||||
"""
|
||||
wrappers = getattr(model, "_kt_wrappers", [])
|
||||
if not wrappers:
|
||||
base_model = model
|
||||
for attr in ["base_model", "model"]:
|
||||
if hasattr(base_model, attr):
|
||||
base_model = getattr(base_model, attr)
|
||||
wrappers = getattr(base_model, "_kt_wrappers", [])
|
||||
if wrappers:
|
||||
break
|
||||
wrappers = _find_kt_wrappers(model) or []
|
||||
if not wrappers:
|
||||
logger.warning("No KT wrappers found, skipping KT MoE loading")
|
||||
return
|
||||
|
||||
has_lora_experts = any(w.lora_experts is not None for w in wrappers)
|
||||
has_fused_lora = any(getattr(w, "_fused_expert_lora_params", None) is not None for w in wrappers)
|
||||
|
||||
if has_lora_experts:
|
||||
load_lora_experts_from_adapter(model, adapter_path)
|
||||
else:
|
||||
logger.info("No lora_experts in KT wrappers (PEFT LoRA is loaded by PEFT directly)")
|
||||
|
||||
if has_fused_lora:
|
||||
_load_fused_expert_lora(wrappers, adapter_path)
|
||||
|
||||
if not has_lora_experts and not has_fused_lora:
|
||||
logger.info("No lora_experts or fused expert LoRA in KT wrappers")
|
||||
|
|
|
|||
|
|
@ -40,8 +40,28 @@ def extract_moe_weights(
|
|||
|
||||
Returns (gate_proj, up_proj, down_proj) with shape
|
||||
[expert_num, out_features, in_features].
|
||||
|
||||
Supports two formats:
|
||||
- ModuleList of Linear experts (transformers v4 style)
|
||||
- Fused Parameters (transformers v5 style): single module with
|
||||
``gate_up_proj`` [E, 2*I, H] and ``down_proj`` [E, H, I] tensors.
|
||||
"""
|
||||
from .arch import detect_fused_experts
|
||||
|
||||
experts = getattr(moe_module, moe_config.experts_attr)
|
||||
|
||||
# Fused format (transformers v5): a single nn.Module with gate_up_proj/down_proj tensors
|
||||
if detect_fused_experts(experts):
|
||||
gate_up = getattr(experts, "gate_up_proj").data
|
||||
down_fused = getattr(experts, "down_proj").data
|
||||
# gate_up_proj is [E, 2*I, H], split into gate [E, I, H] and up [E, I, H]
|
||||
intermediate = gate_up.shape[1] // 2
|
||||
gate_proj = gate_up[:, :intermediate, :].contiguous()
|
||||
up_proj = gate_up[:, intermediate:, :].contiguous()
|
||||
# down_proj is already [E, H, I]
|
||||
down_proj = down_fused.contiguous()
|
||||
return gate_proj, up_proj, down_proj
|
||||
|
||||
gate_name, up_name, down_name = moe_config.weight_names
|
||||
|
||||
gather_params: list[torch.nn.Parameter] = []
|
||||
|
|
@ -92,10 +112,27 @@ def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchCon
|
|||
"""
|
||||
Clear original expert weights to free memory after KT weights are loaded.
|
||||
"""
|
||||
from .arch import detect_fused_experts
|
||||
|
||||
experts = getattr(moe_module, moe_config.experts_attr, None)
|
||||
if experts is None:
|
||||
return
|
||||
|
||||
# Fused format: replace gate_up_proj/down_proj tensors with zero-storage placeholders
|
||||
if detect_fused_experts(experts):
|
||||
for name in ("gate_up_proj", "down_proj"):
|
||||
param = getattr(experts, name, None)
|
||||
if not isinstance(param, torch.nn.Parameter):
|
||||
continue
|
||||
original_dtype = param.dtype
|
||||
tiny_storage = torch.UntypedStorage(1, device="cpu")
|
||||
fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_(
|
||||
tiny_storage, storage_offset=0, size=param.shape,
|
||||
stride=[0] * len(param.shape),
|
||||
)
|
||||
experts._parameters[name] = nn.Parameter(fake_tensor, requires_grad=False)
|
||||
return
|
||||
|
||||
def _iter_weight_params():
|
||||
for expert in experts:
|
||||
for weight_name in moe_config.weight_names:
|
||||
|
|
|
|||
|
|
@ -264,12 +264,17 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
|
|||
model_container, layers = _get_model_container_and_layers(model, purpose="wrapping")
|
||||
logger.info(f"Total layers={len(layers)}, is_rank_0={is_rank_0}")
|
||||
|
||||
from .arch import detect_fused_experts as _detect_fused
|
||||
|
||||
for layer_idx, layer in enumerate(layers):
|
||||
moe_module = get_moe_module(layer, moe_config)
|
||||
if moe_module is None:
|
||||
continue
|
||||
|
||||
logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method})")
|
||||
_layer_experts = getattr(moe_module, moe_config.experts_attr, None)
|
||||
_layer_is_fused = _detect_fused(_layer_experts)
|
||||
|
||||
logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method}, fused={_layer_is_fused})")
|
||||
|
||||
# Only rank 0 loads weights and initializes KT kernel
|
||||
gate_proj, up_proj, down_proj = None, None, None
|
||||
|
|
@ -312,7 +317,6 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
|
|||
num_experts_per_tok=moe_config.num_experts_per_tok,
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_size=moe_config.intermediate_size,
|
||||
gpu_experts_mask=None,
|
||||
num_gpu_experts=0,
|
||||
cpuinfer_threads=getattr(cfg, "kt_num_threads", 1),
|
||||
threadpool_count=threadpool_count,
|
||||
|
|
@ -370,7 +374,8 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
|
|||
layer_idx=layer_idx,
|
||||
lora_experts=lora_experts,
|
||||
)
|
||||
layer_wrapper._skip_lora = "SkipLoRA" in kt_method
|
||||
layer_wrapper._fused_experts = _layer_is_fused
|
||||
layer_wrapper._lora_rank = lora_rank
|
||||
|
||||
setattr(layer, moe_config.moe_layer_attr, layer_wrapper)
|
||||
# Base weights have been copied into the C++ kernel's internal BufferB format.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue