diff --git a/kt-kernel/python/sft/arch.py b/kt-kernel/python/sft/arch.py
index 8e60c581..43b2e2cb 100644
--- a/kt-kernel/python/sft/arch.py
+++ b/kt-kernel/python/sft/arch.py
@@ -98,16 +98,17 @@ def get_moe_arch_config(config) -> MOEArchConfig:
             has_shared_experts=getattr(config, "n_shared_experts", 0) > 0,
             router_type="deepseek_gate",
         )
-    if "Qwen2Moe" in arch or "Qwen3Moe" in arch:
+    if "Qwen2Moe" in arch or "Qwen3Moe" in arch or "Qwen3_5Moe" in arch:
+        cfg = getattr(config, "text_config", config)
         return MOEArchConfig(
             moe_layer_attr="mlp",
             router_attr="gate",
             experts_attr="experts",
             weight_names=("gate_proj", "up_proj", "down_proj"),
-            expert_num=config.num_experts,
-            intermediate_size=config.moe_intermediate_size,
-            num_experts_per_tok=config.num_experts_per_tok,
-            has_shared_experts=getattr(config, "shared_expert_intermediate_size", 0) > 0,
+            expert_num=cfg.num_experts,
+            intermediate_size=cfg.moe_intermediate_size,
+            num_experts_per_tok=cfg.num_experts_per_tok,
+            has_shared_experts=getattr(cfg, "shared_expert_intermediate_size", 0) > 0,
         )
     if "Mixtral" in arch:
         return MOEArchConfig(
@@ -123,7 +124,7 @@ def get_moe_arch_config(config) -> MOEArchConfig:
 
     raise KTAMXModelNotSupportedError(
         f"Model architecture {arch} not supported for KT AMX. "
-        "Supported architectures: DeepseekV2, DeepseekV3, Qwen2Moe, Qwen3Moe, Mixtral"
+        "Supported architectures: DeepseekV2, DeepseekV3, Qwen2Moe, Qwen3Moe, Qwen3_5Moe, Mixtral"
     )
 
 
@@ -154,8 +155,8 @@ def detect_fused_experts(experts: nn.Module) -> bool:
 
 def _get_layers_prefix(config) -> str:
     arch = config.architectures[0] if getattr(config, "architectures", None) else ""
-    if any(x in arch for x in ["Deepseek", "Qwen", "Mixtral", "Llama"]):
-        return "model.layers"
+    if "Qwen3_5Moe" in arch:
+        return "model.language_model.layers"
     return "model.layers"
 
 
@@ -181,7 +182,7 @@ def _get_model_container_and_layers(model: nn.Module, *, purpose: str) -> tuple[
         if layers is not None and isinstance(layers, (list, tuple, nn.ModuleList)):
             return current, layers
 
-        for attr in ("model", "base_model", "pretrained_model", "module"):
+        for attr in ("model", "base_model", "pretrained_model", "module", "language_model"):
            child = getattr(current, attr, None)
            if isinstance(child, nn.Module) and child is not current:
                to_visit.append(child)
diff --git a/kt-kernel/python/sft/weights.py b/kt-kernel/python/sft/weights.py
index 207f8e4f..c15e2263 100644
--- a/kt-kernel/python/sft/weights.py
+++ b/kt-kernel/python/sft/weights.py
@@ -309,15 +309,23 @@ def load_experts_from_checkpoint_files(
     weight_map = sharded_metadata.get("weight_map", None)
     gate_name, up_name, down_name = moe_config.weight_names
 
-    keys = []
-    for expert_idx in range(moe_config.expert_num):
-        base = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}.{expert_idx}"
-        keys.append(f"{base}.{gate_name}.weight")
-        keys.append(f"{base}.{gate_name}.weight_scale_inv")
-        keys.append(f"{base}.{up_name}.weight")
-        keys.append(f"{base}.{up_name}.weight_scale_inv")
-        keys.append(f"{base}.{down_name}.weight")
-        keys.append(f"{base}.{down_name}.weight_scale_inv")
+    experts_prefix = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}"
+    fused_gate_up_key = f"{experts_prefix}.gate_up_proj"
+    fused_down_key = f"{experts_prefix}.down_proj"
+    is_fused = weight_map is not None and fused_gate_up_key in weight_map
+
+    if is_fused:
+        keys = [fused_gate_up_key, fused_down_key]
+    else:
+        keys = []
+        for expert_idx in range(moe_config.expert_num):
+            base = f"{experts_prefix}.{expert_idx}"
+            keys.append(f"{base}.{gate_name}.weight")
+            keys.append(f"{base}.{gate_name}.weight_scale_inv")
+            keys.append(f"{base}.{up_name}.weight")
+            keys.append(f"{base}.{up_name}.weight_scale_inv")
+            keys.append(f"{base}.{down_name}.weight")
+            keys.append(f"{base}.{down_name}.weight_scale_inv")
 
     keys_by_file: dict[str, list[str]] = {}
     mapped_count = 0
@@ -362,6 +370,30 @@ def load_experts_from_checkpoint_files(
         flush=True,
     )
 
+    t2 = time.time()
+    if is_fused:
+        gate_up = tensor_map.get(fused_gate_up_key)
+        down = tensor_map.get(fused_down_key)
+        if gate_up is None or down is None:
+            raise FileNotFoundError(f"Missing fused expert weights for layer {layer_idx}")
+        gate_up = gate_up.cpu().to(torch.bfloat16).contiguous()
+        I = gate_up.shape[1] // 2
+        gate_proj = gate_up[:, :I, :].contiguous()
+        up_proj = gate_up[:, I:, :].contiguous()
+        down_proj = down.cpu().to(torch.bfloat16).contiguous()
+        del gate_up
+        print(
+            f"[kt_moe] Layer {layer_idx}: fused expert format — "
+            f"split gate_up_proj [{gate_proj.shape}] + down [{down_proj.shape}]",
+            flush=True,
+        )
+        print(
+            f"[kt_moe] Layer {layer_idx}: done — dtype={gate_proj.dtype}, "
+            f"shape={gate_proj.shape}, dequant=0.0s, total={time.time()-t0:.1f}s",
+            flush=True,
+        )
+        return gate_proj, up_proj, down_proj
+
     gate_weights = []
     up_weights = []
     down_weights = []
@@ -369,7 +401,7 @@ def load_experts_from_checkpoint_files(
     gate_scales = []
     up_scales = []
     down_scales = []
     for expert_idx in range(moe_config.expert_num):
-        base = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}.{expert_idx}"
+        base = f"{experts_prefix}.{expert_idx}"
         gate_key = f"{base}.{gate_name}.weight"
         up_key = f"{base}.{up_name}.weight"
         down_key = f"{base}.{down_name}.weight"
diff --git a/kt-kernel/python/sft/wrapper.py b/kt-kernel/python/sft/wrapper.py
index 4b29bfd7..06706716 100644
--- a/kt-kernel/python/sft/wrapper.py
+++ b/kt-kernel/python/sft/wrapper.py
@@ -157,7 +157,8 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
 
     is_rank_0 = dist.get_rank() == 0
     moe_config = get_moe_arch_config(model.config)
-    hidden_size = model.config.hidden_size
+    _text_cfg = getattr(model.config, "text_config", model.config)
+    hidden_size = _text_cfg.hidden_size
 
     cfg = _get_kt_config(kt_plugin)
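
Note on the config handling above: both arch.py and wrapper.py now resolve MoE fields through getattr(config, "text_config", config), so a flat Qwen2Moe/Qwen3Moe config and a wrapper-style config that nests the decoder settings under text_config are read the same way. A minimal sketch of the pattern, using stand-in SimpleNamespace objects rather than the real HF config classes:

from types import SimpleNamespace

def resolve_text_config(config):
    # Same pattern as the diff: prefer the nested text_config when it exists,
    # otherwise treat the config itself as the text config.
    return getattr(config, "text_config", config)

# Flat config (Qwen2Moe / Qwen3Moe style): fields live on the top level.
flat = SimpleNamespace(num_experts=64, moe_intermediate_size=1408, hidden_size=2048)

# Nested config (Qwen3_5Moe style assumed here): fields live under text_config.
wrapped = SimpleNamespace(
    text_config=SimpleNamespace(num_experts=128, moe_intermediate_size=1024, hidden_size=4096)
)

for cfg in (flat, wrapped):
    text_cfg = resolve_text_config(cfg)
    print(text_cfg.num_experts, text_cfg.hidden_size)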
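The weights.py hunk at line 309 chooses between two key layouts purely from the safetensors index: if <experts_prefix>.gate_up_proj appears in weight_map, the layer is loaded from the two fused 3D tensors; otherwise the per-expert keys (plus optional weight_scale_inv companions) are enumerated as before. A rough standalone sketch of that branch; the helper name and the hard-coded "mlp.experts" path are illustrative, the real code takes them from moe_config:

def expert_keys_for_layer(weight_map, layers_prefix, layer_idx, expert_num,
                          weight_names=("gate_proj", "up_proj", "down_proj")):
    """Return the tensor names to fetch for one MoE layer.

    Fused checkpoints store two 3D tensors per layer (gate_up_proj / down_proj);
    per-expert checkpoints store one 2D tensor per expert per projection,
    optionally with a *.weight_scale_inv companion for FP8 weights.
    """
    experts_prefix = f"{layers_prefix}.{layer_idx}.mlp.experts"
    fused_gate_up_key = f"{experts_prefix}.gate_up_proj"
    fused_down_key = f"{experts_prefix}.down_proj"

    # Fused layout is detected purely by key presence in the index's weight_map.
    if weight_map is not None and fused_gate_up_key in weight_map:
        return [fused_gate_up_key, fused_down_key]

    keys = []
    for expert_idx in range(expert_num):
        base = f"{experts_prefix}.{expert_idx}"
        for name in weight_names:
            keys.append(f"{base}.{name}.weight")
            keys.append(f"{base}.{name}.weight_scale_inv")
    return keys

# e.g. keys = expert_keys_for_layer(weight_map, "model.language_model.layers", 3, expert_num=128)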
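For the fused path, gate_up_proj is cut in half along dim 1 and each half is made contiguous before being returned as gate_proj and up_proj. A toy reproduction of that slicing; the [num_experts, 2 * intermediate_size, hidden_size] layout is inferred from the slicing in the diff, not verified against an actual Qwen3_5Moe checkpoint:

import torch

num_experts, intermediate, hidden = 4, 8, 16

# Layout assumed from the slicing in the diff: gate and up stacked along dim 1.
gate_up = torch.randn(num_experts, 2 * intermediate, hidden, dtype=torch.bfloat16)

I = gate_up.shape[1] // 2
gate_proj = gate_up[:, :I, :].contiguous()  # first half  -> gate_proj
up_proj = gate_up[:, I:, :].contiguous()    # second half -> up_proj

print(gate_proj.shape, up_proj.shape)  # torch.Size([4, 8, 16]) torch.Size([4, 8, 16])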