feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config,
  _get_layers_prefix returns model.language_model.layers for Qwen3.5,
  _get_model_container_and_layers searches the language_model attr
  (see the sketch after this list)
- weights.py: load_experts_from_checkpoint_files detects fused format
  (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config
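
A minimal sketch of the arch.py / wrapper.py side described in the bullets above. The helper names and the exact architecture string are illustrative assumptions, not the shipped code:

# Sketch only: mirrors the bullets above. Helper names and the
# "Qwen3_5Moe" prefix check are assumptions; the real logic lives in
# arch.py / wrapper.py.
def _get_text_config(config):
    # Qwen3.5 MoE nests hidden_size, expert counts, etc. under text_config;
    # older models keep them at the top level.
    return getattr(config, "text_config", config)

def _get_layers_prefix(config) -> str:
    # Qwen3.5 wraps the decoder in a language_model submodule.
    arch = (getattr(config, "architectures", None) or [""])[0]
    if arch.startswith("Qwen3_5Moe"):
        return "model.language_model.layers"
    return "model.layers"

def get_hidden_size(config) -> int:
    # wrapper.py fallback: top-level hidden_size first, then text_config.
    return getattr(config, "hidden_size", None) or _get_text_config(config).hidden_size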

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mrhaoxx 2026-04-20 17:19:15 +08:00
parent 58d7eabb9b
commit dd1da65d90
3 changed files with 54 additions and 20 deletions

weights.py

@@ -309,15 +309,23 @@ def load_experts_from_checkpoint_files(
     weight_map = sharded_metadata.get("weight_map", None)
     gate_name, up_name, down_name = moe_config.weight_names
-    keys = []
-    for expert_idx in range(moe_config.expert_num):
-        base = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}.{expert_idx}"
-        keys.append(f"{base}.{gate_name}.weight")
-        keys.append(f"{base}.{gate_name}.weight_scale_inv")
-        keys.append(f"{base}.{up_name}.weight")
-        keys.append(f"{base}.{up_name}.weight_scale_inv")
-        keys.append(f"{base}.{down_name}.weight")
-        keys.append(f"{base}.{down_name}.weight_scale_inv")
+    experts_prefix = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}"
+    fused_gate_up_key = f"{experts_prefix}.gate_up_proj"
+    fused_down_key = f"{experts_prefix}.down_proj"
+    is_fused = weight_map is not None and fused_gate_up_key in weight_map
+    if is_fused:
+        keys = [fused_gate_up_key, fused_down_key]
+    else:
+        keys = []
+        for expert_idx in range(moe_config.expert_num):
+            base = f"{experts_prefix}.{expert_idx}"
+            keys.append(f"{base}.{gate_name}.weight")
+            keys.append(f"{base}.{gate_name}.weight_scale_inv")
+            keys.append(f"{base}.{up_name}.weight")
+            keys.append(f"{base}.{up_name}.weight_scale_inv")
+            keys.append(f"{base}.{down_name}.weight")
+            keys.append(f"{base}.{down_name}.weight_scale_inv")
     keys_by_file: dict[str, list[str]] = {}
     mapped_count = 0
@@ -362,6 +370,30 @@ def load_experts_from_checkpoint_files(
         flush=True,
     )
     t2 = time.time()
+    if is_fused:
+        gate_up = tensor_map.get(fused_gate_up_key)
+        down = tensor_map.get(fused_down_key)
+        if gate_up is None or down is None:
+            raise FileNotFoundError(f"Missing fused expert weights for layer {layer_idx}")
+        gate_up = gate_up.cpu().to(torch.bfloat16).contiguous()
+        I = gate_up.shape[1] // 2
+        gate_proj = gate_up[:, :I, :].contiguous()
+        up_proj = gate_up[:, I:, :].contiguous()
+        down_proj = down.cpu().to(torch.bfloat16).contiguous()
+        del gate_up
+        print(
+            f"[kt_moe] Layer {layer_idx}: fused expert format — "
+            f"split gate_up_proj [{gate_proj.shape}] + down [{down_proj.shape}]",
+            flush=True,
+        )
+        print(
+            f"[kt_moe] Layer {layer_idx}: done — dtype={gate_proj.dtype}, "
+            f"shape={gate_proj.shape}, dequant=0.0s, total={time.time()-t0:.1f}s",
+            flush=True,
+        )
+        return gate_proj, up_proj, down_proj
     gate_weights = []
     up_weights = []
     down_weights = []
@@ -369,7 +401,7 @@ def load_experts_from_checkpoint_files(
     up_scales = []
     down_scales = []
     for expert_idx in range(moe_config.expert_num):
-        base = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}.{expert_idx}"
+        base = f"{experts_prefix}.{expert_idx}"
         gate_key = f"{base}.{gate_name}.weight"
         up_key = f"{base}.{up_name}.weight"
         down_key = f"{base}.{down_name}.weight"
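
For reference, a standalone sketch of the gate/up split performed by the fused branch above, on dummy tensors. The [num_experts, 2 * intermediate_size, hidden_size] layout with the gate half first is inferred from the slicing in the diff, not confirmed against the checkpoint spec:

import torch

# Sketch: reproduce the fused gate_up_proj split from the hunk above on
# dummy data. The shape convention ([E, 2*I, H], gate half first) is an
# assumption inferred from the slicing code.
num_experts, intermediate_size, hidden_size = 4, 8, 16
gate_up = torch.randn(num_experts, 2 * intermediate_size, hidden_size, dtype=torch.bfloat16)

half = gate_up.shape[1] // 2  # split point along dim 1
gate_proj = gate_up[:, :half, :].contiguous()
up_proj = gate_up[:, half:, :].contiguous()

assert gate_proj.shape == (num_experts, intermediate_size, hidden_size)
assert up_proj.shape == (num_experts, intermediate_size, hidden_size)

Because each half is a non-contiguous view, .contiguous() materializes a copy, which is why the diff can del gate_up right after splitting.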