mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 03:39:48 +00:00
feat(sft): add Qwen3.5 MoE support + fused checkpoint loading
- arch.py: add Qwen3_5Moe arch match, read config from text_config, _get_layers_prefix returns model.language_model.layers for Qwen3.5, _get_model_container_and_layers searches language_model attr - weights.py: load_experts_from_checkpoint_files detects fused format (gate_up_proj in weight_map) and splits into gate/up/down - wrapper.py: hidden_size fallback to text_config Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
58d7eabb9b
commit
dd1da65d90
3 changed files with 54 additions and 20 deletions
|
|
@ -98,16 +98,17 @@ def get_moe_arch_config(config) -> MOEArchConfig:
|
|||
has_shared_experts=getattr(config, "n_shared_experts", 0) > 0,
|
||||
router_type="deepseek_gate",
|
||||
)
|
||||
if "Qwen2Moe" in arch or "Qwen3Moe" in arch:
|
||||
if "Qwen2Moe" in arch or "Qwen3Moe" in arch or "Qwen3_5Moe" in arch:
|
||||
cfg = getattr(config, "text_config", config)
|
||||
return MOEArchConfig(
|
||||
moe_layer_attr="mlp",
|
||||
router_attr="gate",
|
||||
experts_attr="experts",
|
||||
weight_names=("gate_proj", "up_proj", "down_proj"),
|
||||
expert_num=config.num_experts,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
num_experts_per_tok=config.num_experts_per_tok,
|
||||
has_shared_experts=getattr(config, "shared_expert_intermediate_size", 0) > 0,
|
||||
expert_num=cfg.num_experts,
|
||||
intermediate_size=cfg.moe_intermediate_size,
|
||||
num_experts_per_tok=cfg.num_experts_per_tok,
|
||||
has_shared_experts=getattr(cfg, "shared_expert_intermediate_size", 0) > 0,
|
||||
)
|
||||
if "Mixtral" in arch:
|
||||
return MOEArchConfig(
|
||||
|
|
@ -123,7 +124,7 @@ def get_moe_arch_config(config) -> MOEArchConfig:
|
|||
|
||||
raise KTAMXModelNotSupportedError(
|
||||
f"Model architecture {arch} not supported for KT AMX. "
|
||||
"Supported architectures: DeepseekV2, DeepseekV3, Qwen2Moe, Qwen3Moe, Mixtral"
|
||||
"Supported architectures: DeepseekV2, DeepseekV3, Qwen2Moe, Qwen3Moe, Qwen3_5Moe, Mixtral"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -154,8 +155,8 @@ def detect_fused_experts(experts: nn.Module) -> bool:
|
|||
|
||||
def _get_layers_prefix(config) -> str:
|
||||
arch = config.architectures[0] if getattr(config, "architectures", None) else ""
|
||||
if any(x in arch for x in ["Deepseek", "Qwen", "Mixtral", "Llama"]):
|
||||
return "model.layers"
|
||||
if "Qwen3_5Moe" in arch:
|
||||
return "model.language_model.layers"
|
||||
return "model.layers"
|
||||
|
||||
|
||||
|
|
@ -181,7 +182,7 @@ def _get_model_container_and_layers(model: nn.Module, *, purpose: str) -> tuple[
|
|||
if layers is not None and isinstance(layers, (list, tuple, nn.ModuleList)):
|
||||
return current, layers
|
||||
|
||||
for attr in ("model", "base_model", "pretrained_model", "module"):
|
||||
for attr in ("model", "base_model", "pretrained_model", "module", "language_model"):
|
||||
child = getattr(current, attr, None)
|
||||
if isinstance(child, nn.Module) and child is not current:
|
||||
to_visit.append(child)
|
||||
|
|
|
|||
|
|
@ -309,15 +309,23 @@ def load_experts_from_checkpoint_files(
|
|||
weight_map = sharded_metadata.get("weight_map", None)
|
||||
|
||||
gate_name, up_name, down_name = moe_config.weight_names
|
||||
keys = []
|
||||
for expert_idx in range(moe_config.expert_num):
|
||||
base = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}.{expert_idx}"
|
||||
keys.append(f"{base}.{gate_name}.weight")
|
||||
keys.append(f"{base}.{gate_name}.weight_scale_inv")
|
||||
keys.append(f"{base}.{up_name}.weight")
|
||||
keys.append(f"{base}.{up_name}.weight_scale_inv")
|
||||
keys.append(f"{base}.{down_name}.weight")
|
||||
keys.append(f"{base}.{down_name}.weight_scale_inv")
|
||||
experts_prefix = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}"
|
||||
fused_gate_up_key = f"{experts_prefix}.gate_up_proj"
|
||||
fused_down_key = f"{experts_prefix}.down_proj"
|
||||
is_fused = weight_map is not None and fused_gate_up_key in weight_map
|
||||
|
||||
if is_fused:
|
||||
keys = [fused_gate_up_key, fused_down_key]
|
||||
else:
|
||||
keys = []
|
||||
for expert_idx in range(moe_config.expert_num):
|
||||
base = f"{experts_prefix}.{expert_idx}"
|
||||
keys.append(f"{base}.{gate_name}.weight")
|
||||
keys.append(f"{base}.{gate_name}.weight_scale_inv")
|
||||
keys.append(f"{base}.{up_name}.weight")
|
||||
keys.append(f"{base}.{up_name}.weight_scale_inv")
|
||||
keys.append(f"{base}.{down_name}.weight")
|
||||
keys.append(f"{base}.{down_name}.weight_scale_inv")
|
||||
|
||||
keys_by_file: dict[str, list[str]] = {}
|
||||
mapped_count = 0
|
||||
|
|
@ -362,6 +370,30 @@ def load_experts_from_checkpoint_files(
|
|||
flush=True,
|
||||
)
|
||||
|
||||
t2 = time.time()
|
||||
if is_fused:
|
||||
gate_up = tensor_map.get(fused_gate_up_key)
|
||||
down = tensor_map.get(fused_down_key)
|
||||
if gate_up is None or down is None:
|
||||
raise FileNotFoundError(f"Missing fused expert weights for layer {layer_idx}")
|
||||
gate_up = gate_up.cpu().to(torch.bfloat16).contiguous()
|
||||
I = gate_up.shape[1] // 2
|
||||
gate_proj = gate_up[:, :I, :].contiguous()
|
||||
up_proj = gate_up[:, I:, :].contiguous()
|
||||
down_proj = down.cpu().to(torch.bfloat16).contiguous()
|
||||
del gate_up
|
||||
print(
|
||||
f"[kt_moe] Layer {layer_idx}: fused expert format — "
|
||||
f"split gate_up_proj [{gate_proj.shape}] + down [{down_proj.shape}]",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f"[kt_moe] Layer {layer_idx}: done — dtype={gate_proj.dtype}, "
|
||||
f"shape={gate_proj.shape}, dequant=0.0s, total={time.time()-t0:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
return gate_proj, up_proj, down_proj
|
||||
|
||||
gate_weights = []
|
||||
up_weights = []
|
||||
down_weights = []
|
||||
|
|
@ -369,7 +401,7 @@ def load_experts_from_checkpoint_files(
|
|||
up_scales = []
|
||||
down_scales = []
|
||||
for expert_idx in range(moe_config.expert_num):
|
||||
base = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}.{expert_idx}"
|
||||
base = f"{experts_prefix}.{expert_idx}"
|
||||
gate_key = f"{base}.{gate_name}.weight"
|
||||
up_key = f"{base}.{up_name}.weight"
|
||||
down_key = f"{base}.{down_name}.weight"
|
||||
|
|
|
|||
|
|
@ -157,7 +157,8 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
|
|||
is_rank_0 = dist.get_rank() == 0
|
||||
|
||||
moe_config = get_moe_arch_config(model.config)
|
||||
hidden_size = model.config.hidden_size
|
||||
_text_cfg = getattr(model.config, "text_config", model.config)
|
||||
hidden_size = _text_cfg.hidden_size
|
||||
|
||||
cfg = _get_kt_config(kt_plugin)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue