Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2026-04-29 12:19:50 +00:00
* feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.
* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use kt_ prefix matching dict keys; eliminates _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute, in_ckpt_first_forward vars (assigned but never read), fallback _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline (1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False, auto-set to True by trainer_config_process when gradient checkpointing is enabled

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument, but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters (gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert nn.Linear modules. PEFT cannot attach LoRA to these, so we create KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers for the optimizer, and pre-assigned .grad for C++ backward.

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA, get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config, _get_layers_prefix returns model.language_model.layers for Qwen3.5, _get_model_container_and_layers searches language_model attr
- weights.py: load_experts_from_checkpoint_files detects fused format (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [fix](sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original method names. This caused AttributeError on Qwen3.5-35B and other models.
* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace, sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp, moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
282 lines, 9.4 KiB, Python
# MoE architecture configuration and model utilities
# SPDX-License-Identifier: Apache-2.0

"""
MoE architecture detection and model navigation utilities.

This is a leaf module: it imports nothing from the other sft/ submodules.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)
|
|
# =============================================================================
# Exceptions
# =============================================================================


class KTAMXError(Exception):
    """Base exception for KT AMX errors."""


class KTAMXNotAvailableError(KTAMXError):
    """kt_kernel not installed or AMX not supported."""


class KTAMXModelNotSupportedError(KTAMXError):
    """Model architecture not supported."""


class KTAMXConfigError(KTAMXError):
    """Configuration error."""
|
|
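# =============================================================================
# MoE Configuration
# =============================================================================


@dataclass
class MOEArchConfig:
    """MoE architecture configuration for different model types."""

    moe_layer_attr: str  # attribute of a decoder layer holding the MoE block, e.g. "mlp"
    router_attr: str  # attribute of the MoE block holding the router, e.g. "gate"
    experts_attr: str  # attribute of the MoE block holding the routed experts
    weight_names: tuple[str, str, str]  # (gate, up, down) projection names in one expert
    expert_num: int  # number of routed experts per MoE layer
    intermediate_size: int  # hidden size of each expert's FFN
    num_experts_per_tok: int  # top-k experts selected per token
    has_shared_experts: bool = False  # block also has always-active shared expert(s)
    router_type: str = "linear"  # "linear" or "deepseek_gate"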
|
|
def get_moe_arch_config(config) -> MOEArchConfig:
    """
    Get MoE architecture configuration based on model type.

    The model family is matched by substring against ``config.architectures[0]``
    (for example, ``"Qwen3MoeForCausalLM"`` matches ``"Qwen3Moe"``).

    Args:
        config: HuggingFace model configuration

    Returns:
        MOEArchConfig for the model

    Raises:
        KTAMXModelNotSupportedError: If model architecture is not supported
    """
    arch = config.architectures[0] if getattr(config, "architectures", None) else ""

    # DeepSeek-V2 and DeepSeek-V3 share the same MoE layout
    if "DeepseekV2" in arch or "DeepseekV3" in arch:
        return MOEArchConfig(
            moe_layer_attr="mlp",
            router_attr="gate",
            experts_attr="experts",
            weight_names=("gate_proj", "up_proj", "down_proj"),
            expert_num=config.n_routed_experts,
            intermediate_size=config.moe_intermediate_size,
            num_experts_per_tok=config.num_experts_per_tok,
            # `or 0` guards against n_shared_experts being set to None
            has_shared_experts=(getattr(config, "n_shared_experts", 0) or 0) > 0,
            router_type="deepseek_gate",
        )
    if "Qwen2Moe" in arch or "Qwen3Moe" in arch or "Qwen3_5Moe" in arch:
        # Qwen3.5 keeps the text-model settings under `text_config`
        cfg = getattr(config, "text_config", config)
        return MOEArchConfig(
            moe_layer_attr="mlp",
            router_attr="gate",
            experts_attr="experts",
            weight_names=("gate_proj", "up_proj", "down_proj"),
            expert_num=cfg.num_experts,
            intermediate_size=cfg.moe_intermediate_size,
            num_experts_per_tok=cfg.num_experts_per_tok,
            has_shared_experts=(getattr(cfg, "shared_expert_intermediate_size", 0) or 0) > 0,
        )
    if "Mixtral" in arch:
        return MOEArchConfig(
            moe_layer_attr="block_sparse_moe",
            router_attr="gate",
            experts_attr="experts",
            weight_names=("w1", "w3", "w2"),
            expert_num=config.num_local_experts,
            intermediate_size=config.intermediate_size,
            num_experts_per_tok=config.num_experts_per_tok,
            has_shared_experts=False,
        )

    raise KTAMXModelNotSupportedError(
        f"Model architecture {arch!r} not supported for KT AMX. "
        "Supported architectures: DeepseekV2, DeepseekV3, Qwen2Moe, Qwen3Moe, Qwen3_5Moe, Mixtral"
    )
|
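To make explicit which config fields each branch reads, here is a stdlib-only sketch that mirrors the substring dispatch above on `types.SimpleNamespace` stand-ins for HuggingFace configs. `expert_fields` is a hypothetical helper, not part of kt_kernel; it returns `(expert_num, intermediate_size, num_experts_per_tok, has_shared_experts)`. The architecture strings and numbers below are illustrative assumptions, not values copied from released checkpoints. The sketch treats a `None` shared-expert count as zero.

```python
from types import SimpleNamespace


def expert_fields(config):
    """Hypothetical mirror of get_moe_arch_config's field lookups (sketch only)."""
    archs = getattr(config, "architectures", None)
    arch = archs[0] if archs else ""
    if "DeepseekV2" in arch or "DeepseekV3" in arch:
        shared = (getattr(config, "n_shared_experts", 0) or 0) > 0
        return (config.n_routed_experts, config.moe_intermediate_size,
                config.num_experts_per_tok, shared)
    if "Qwen2Moe" in arch or "Qwen3Moe" in arch or "Qwen3_5Moe" in arch:
        # Composite configs keep the text-model settings in text_config
        cfg = getattr(config, "text_config", config)
        shared = (getattr(cfg, "shared_expert_intermediate_size", 0) or 0) > 0
        return (cfg.num_experts, cfg.moe_intermediate_size,
                cfg.num_experts_per_tok, shared)
    if "Mixtral" in arch:
        return (config.num_local_experts, config.intermediate_size,
                config.num_experts_per_tok, False)
    raise ValueError(f"unsupported architecture {arch!r}")


# Illustrative stand-in configs (values are assumptions)
mixtral = SimpleNamespace(architectures=["MixtralForCausalLM"],
                          num_local_experts=8, intermediate_size=14336,
                          num_experts_per_tok=2)
nested = SimpleNamespace(
    architectures=["Qwen3_5MoeForConditionalGeneration"],
    text_config=SimpleNamespace(num_experts=256, moe_intermediate_size=512,
                                num_experts_per_tok=8,
                                shared_expert_intermediate_size=512),
)
deepseek = SimpleNamespace(architectures=["DeepseekV2ForCausalLM"],
                           n_routed_experts=64, moe_intermediate_size=1408,
                           num_experts_per_tok=6, n_shared_experts=None)

print(expert_fields(mixtral))   # (8, 14336, 2, False)
print(expert_fields(nested))    # (256, 512, 8, True)
print(expert_fields(deepseek))  # (64, 1408, 6, False)
```

Note that the nested config has no top-level `num_experts`; only the `text_config` fallback makes the Qwen branch work for it.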
|
def get_moe_module(layer: nn.Module, moe_config: MOEArchConfig) -> nn.Module | None:
    """Return the layer's MoE block, or None if the layer is dense (has no experts)."""
    moe_module = getattr(layer, moe_config.moe_layer_attr, None)
    if moe_module is None:
        return None
    if not hasattr(moe_module, moe_config.experts_attr):
        return None
    return moe_module
|
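The `None` return is how callers skip dense layers; DeepSeek-V2/V3, for instance, keep ordinary MLPs in their first few decoder layers. A stdlib sketch with hypothetical stand-in objects (`SimpleNamespace` in place of `nn.Module`) that applies the same two checks and collects the indices of MoE layers:

```python
from types import SimpleNamespace


def moe_layer_indices(layers, moe_layer_attr="mlp", experts_attr="experts"):
    """Return indices of layers whose MoE block exposes an experts attribute (sketch)."""
    indices = []
    for i, layer in enumerate(layers):
        block = getattr(layer, moe_layer_attr, None)
        # Same two checks as get_moe_module: block present, and it has experts
        if block is not None and hasattr(block, experts_attr):
            indices.append(i)
    return indices


dense = SimpleNamespace(mlp=SimpleNamespace(gate_proj=0, up_proj=0, down_proj=0))
sparse = SimpleNamespace(mlp=SimpleNamespace(gate="router", experts=["e0", "e1"]))
no_mlp = SimpleNamespace(self_attn="attn")

print(moe_layer_indices([dense, sparse, sparse, no_mlp]))  # [1, 2]
```

For Mixtral the same helper would be called with `moe_layer_attr="block_sparse_moe"`.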
|
def detect_fused_experts(experts: nn.Module) -> bool:
    """Detect if experts module uses the transformers v5 fused format.

    Fused format: a single Module with ``gate_up_proj`` [E, 2I, H] and
    ``down_proj`` [E, H, I] 3-D tensors instead of a ModuleList of Linear experts
    (E = number of experts, I = expert intermediate size, H = hidden size).
    """
    if experts is None:
        return False
    gate_up = getattr(experts, "gate_up_proj", None)
    down = getattr(experts, "down_proj", None)
    if isinstance(gate_up, torch.Tensor) and isinstance(down, torch.Tensor):
        return gate_up.dim() == 3 and down.dim() == 3
    return False
|
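With the fused layout, per-expert gate and up matrices can be recovered by splitting `gate_up_proj` along dim 1 into two halves of size I. The sketch below does this on nested Python lists to stay dependency-free; with PyTorch the equivalent call is `gate_up_proj.chunk(2, dim=1)`. That the gate half comes first is an assumption based on the name `gate_up_proj`; check the ordering against the model's own forward pass before relying on it.

```python
def split_fused_gate_up(gate_up):
    """Split an [E][2I][H] nested list into gate [E][I][H] and up [E][I][H] (sketch).

    Assumes the first I rows of each expert are the gate projection.
    """
    gate, up = [], []
    for expert_rows in gate_up:
        if len(expert_rows) % 2:
            raise ValueError("dim 1 of gate_up_proj must be even (2 * I)")
        half = len(expert_rows) // 2
        gate.append(expert_rows[:half])
        up.append(expert_rows[half:])
    return gate, up


# E=2 experts, I=1, H=3: each expert has 2*I = 2 rows of length H
gate_up = [
    [[1, 1, 1], [2, 2, 2]],  # expert 0: gate row, up row
    [[3, 3, 3], [4, 4, 4]],  # expert 1
]
gate, up = split_fused_gate_up(gate_up)
print(gate)  # [[[1, 1, 1]], [[3, 3, 3]]]
print(up)    # [[[2, 2, 2]], [[4, 4, 4]]]
```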
|
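def _get_layers_prefix(config) -> str:
    """Return the dotted path of the decoder layer stack, e.g. ``"model.layers"``."""
    arch = config.architectures[0] if getattr(config, "architectures", None) else ""
    if "Qwen3_5Moe" in arch:
        # Qwen3.5 nests the text backbone under `language_model`
        return "model.language_model.layers"
    return "model.layers"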
|
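The prefix is presumably combined with the attribute names from `MOEArchConfig` when expert weights are looked up by name. A sketch of that composition, assuming the usual HuggingFace per-expert key layout `<prefix>.<layer>.<moe_layer_attr>.<experts_attr>.<expert>.<weight_name>.weight`; `expert_weight_keys` is a hypothetical helper, not a kt_kernel function, and fused checkpoints (which store `gate_up_proj` per layer) do not follow this per-expert layout.

```python
def expert_weight_keys(prefix, layer_idx, expert_idx,
                       moe_layer_attr="mlp", experts_attr="experts",
                       weight_names=("gate_proj", "up_proj", "down_proj")):
    """Build (gate, up, down) checkpoint keys for one expert (sketch, assumed layout)."""
    base = f"{prefix}.{layer_idx}.{moe_layer_attr}.{experts_attr}.{expert_idx}"
    return tuple(f"{base}.{name}.weight" for name in weight_names)


print(expert_weight_keys("model.layers", 3, 7)[0])
# model.layers.3.mlp.experts.7.gate_proj.weight
print(expert_weight_keys("model.layers", 0, 1, moe_layer_attr="block_sparse_moe",
                         weight_names=("w1", "w3", "w2"))[2])
# model.layers.0.block_sparse_moe.experts.1.w2.weight
```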
|
def _get_model_container_and_layers(
    model: nn.Module, *, purpose: str
) -> tuple[nn.Module, nn.ModuleList | list | tuple]:
    """
    Resolve the transformer layer container for KT integration.

    KT expects the transformer block stack to be accessible as `<container>.layers`.
    The container is found by a breadth-first search through common wrapper
    attributes, which handles PEFT PeftModel, TRL value-head models, DDP wrappers
    and models that nest the text backbone under ``language_model``.
    """
    to_visit: list[nn.Module] = [model]
    visited: set[int] = set()
    visited_types: list[str] = []

    while to_visit:
        current = to_visit.pop(0)
        if id(current) in visited:
            continue
        visited.add(id(current))
        visited_types.append(type(current).__name__)

        layers = getattr(current, "layers", None)
        if layers is not None and isinstance(layers, (list, tuple, nn.ModuleList)):
            return current, layers

        for attr in ("model", "base_model", "pretrained_model", "module", "language_model"):
            child = getattr(current, attr, None)
            if isinstance(child, nn.Module) and child is not current:
                to_visit.append(child)

        get_base_model = getattr(current, "get_base_model", None)
        if callable(get_base_model):
            try:
                base = get_base_model()
            except Exception:
                base = None
            if isinstance(base, nn.Module) and base is not current:
                to_visit.append(base)

    visited_preview = ", ".join(visited_types[:6])
    if len(visited_types) > 6:
        visited_preview += ", ..."

    raise KTAMXConfigError(
        f"Model does not expose a .model.layers or .layers attribute for KT {purpose}. "
        "Tried unwrapping via model/base_model/pretrained_model/module/language_model/get_base_model; "
        f"visited: {visited_preview}"
    )
|
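The search above is a breadth-first traversal over a fixed list of wrapper attributes, with an `id()`-based visited set so that wrappers pointing at each other cannot cause an infinite loop. A minimal stdlib sketch of the same idea on plain classes (the class names are hypothetical stand-ins for DDP, PEFT and HF model objects); it omits the `get_base_model()` probe and uses `collections.deque`, whose `popleft()` is O(1) where `list.pop(0)` is O(n):

```python
from collections import deque

WRAPPER_ATTRS = ("model", "base_model", "pretrained_model", "module", "language_model")


def find_layer_container(root):
    """Breadth-first search for the first object exposing a list/tuple `layers`."""
    queue = deque([root])
    seen = set()
    while queue:
        current = queue.popleft()
        if id(current) in seen:
            continue  # reached through a second path; already inspected
        seen.add(id(current))
        layers = getattr(current, "layers", None)
        if isinstance(layers, (list, tuple)):
            return current, layers
        for attr in WRAPPER_ATTRS:
            child = getattr(current, attr, None)
            if child is not None and child is not current:
                queue.append(child)
    raise LookupError("no .layers found")


class Backbone:  # stands in for e.g. a *Model holding the decoder stack
    def __init__(self):
        self.layers = ["block0", "block1"]


class CausalLM:  # stands in for e.g. a *ForCausalLM head model
    def __init__(self):
        self.model = Backbone()


class PeftLike:  # stands in for a PEFT wrapper
    def __init__(self, inner):
        self.base_model = inner
        self.model = inner  # two paths to the same object; visited once


class DDPLike:  # stands in for DistributedDataParallel
    def __init__(self, inner):
        self.module = inner


container, layers = find_layer_container(DDPLike(PeftLike(CausalLM())))
print(type(container).__name__, layers)  # Backbone ['block0', 'block1']
```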
|
def move_non_experts_to_gpu(
    model: nn.Module,
    moe_config: MOEArchConfig | None = None,
    device: str = "cuda:0",
) -> None:
    """Move non-expert parameters to GPU after loading (routed experts stay on CPU).

    Embeddings, the final norm, lm_head, attention, layer norms, dense MLPs,
    routers and shared experts are moved to ``device``.
    """
    if moe_config is None:
        config = getattr(model, "config", None)
        if config is None:
            raise KTAMXConfigError("Model config is required to infer MoE architecture.")
        moe_config = get_moe_arch_config(config)

    container, layers = _get_model_container_and_layers(model, purpose="placement")

    if hasattr(container, "embed_tokens"):
        container.embed_tokens.to(device)
    if hasattr(container, "norm"):
        container.norm.to(device)
    if hasattr(model, "lm_head"):
        model.lm_head.to(device)

    for layer in layers:
        if hasattr(layer, "self_attn"):
            layer.self_attn.to(device)

        if hasattr(layer, "input_layernorm"):
            layer.input_layernorm.to(device)
        if hasattr(layer, "post_attention_layernorm"):
            layer.post_attention_layernorm.to(device)

        moe_module = getattr(layer, moe_config.moe_layer_attr, None)
        if moe_module is None or not hasattr(moe_module, moe_config.experts_attr):
            # Dense layer: the whole MLP belongs on the GPU
            if hasattr(layer, "mlp"):
                layer.mlp.to(device)
            continue

        router = getattr(moe_module, moe_config.router_attr, None)
        if router is not None:
            router.to(device)

        # DeepSeek uses `shared_experts`; Qwen2-MoE-style blocks use
        # `shared_expert` plus `shared_expert_gate`
        for attr in ("shared_experts", "shared_expert", "shared_expert_gate"):
            shared = getattr(moe_module, attr, None)
            if shared is not None:
                shared.to(device)

    logger.info("Moved non-expert parameters to %s", device)
|
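Because the function moves submodules in place and returns nothing, it can help to preview the placement policy first. The sketch below is a hypothetical dry run over `SimpleNamespace` stand-in layers that lists the attribute paths that would be sent to the GPU. It mirrors the per-layer rules above for attention, norms, dense MLPs, the router and `shared_experts`, and, like the function, never touches the routed experts.

```python
from types import SimpleNamespace

LAYER_PARTS = ("self_attn", "input_layernorm", "post_attention_layernorm")


def plan_gpu_moves(layers, moe_layer_attr="mlp", router_attr="gate",
                   experts_attr="experts"):
    """List the per-layer attribute paths that would be moved to the GPU (sketch)."""
    moves = []
    for i, layer in enumerate(layers):
        moves += [f"layers.{i}.{p}" for p in LAYER_PARTS if hasattr(layer, p)]
        block = getattr(layer, moe_layer_attr, None)
        if block is None or not hasattr(block, experts_attr):
            # Dense layer: the whole MLP goes to the GPU
            if hasattr(layer, "mlp"):
                moves.append(f"layers.{i}.mlp")
            continue
        if getattr(block, router_attr, None) is not None:
            moves.append(f"layers.{i}.{moe_layer_attr}.{router_attr}")
        if getattr(block, "shared_experts", None) is not None:
            moves.append(f"layers.{i}.{moe_layer_attr}.shared_experts")
        # Routed experts are deliberately left where they are (CPU)
    return moves


dense = SimpleNamespace(self_attn=1, input_layernorm=1, post_attention_layernorm=1,
                        mlp=SimpleNamespace())
moe = SimpleNamespace(self_attn=1, input_layernorm=1, post_attention_layernorm=1,
                      mlp=SimpleNamespace(gate=1, experts=[], shared_experts=1))
for path in plan_gpu_moves([dense, moe]):
    print(path)
```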
|
def get_expert_device(model: nn.Module, moe_config: MOEArchConfig | None = None) -> str:
    """Return the device type (e.g. "cpu" or "cuda") of the first routed expert found.

    Returns "unknown" if the model has no config, its layer stack cannot be
    located, or no expert weights are found.
    """
    if moe_config is None:
        config = getattr(model, "config", None)
        if config is None:
            return "unknown"
        moe_config = get_moe_arch_config(config)

    try:
        _, layers = _get_model_container_and_layers(model, purpose="expert device probing")
    except KTAMXConfigError:
        return "unknown"

    for layer in layers:
        moe_module = getattr(layer, moe_config.moe_layer_attr, None)
        if moe_module is None:
            continue
        experts = getattr(moe_module, moe_config.experts_attr, None)
        if experts is None:
            continue
        # transformers v5 fused format: one module holding 3-D weight tensors
        if detect_fused_experts(experts):
            return experts.gate_up_proj.device.type
        # Per-expert format: an indexable list of expert modules
        if not isinstance(experts, (nn.ModuleList, list, tuple)) or len(experts) == 0:
            continue
        gate_proj = getattr(experts[0], moe_config.weight_names[0], None)
        if gate_proj is not None:
            return gate_proj.weight.device.type

    return "unknown"
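A stdlib sketch of expert-device probing that covers both expert layouts. `FakeTensor` and `FakeLinear` are hypothetical duck-typed stand-ins for `torch.Tensor` and `nn.Linear` (only `.dim()`, `.device.type` and `.weight` are modelled), and the fused check follows the same shape test as `detect_fused_experts`. A typical use is a sanity check before training that the routed experts really ended up on the CPU.

```python
from types import SimpleNamespace


class FakeTensor:
    """Duck-typed stand-in exposing only .dim() and .device.type."""

    def __init__(self, ndim, device_type):
        self._ndim = ndim
        self.device = SimpleNamespace(type=device_type)

    def dim(self):
        return self._ndim


class FakeLinear:
    """Stand-in for nn.Linear: only a 2-D .weight."""

    def __init__(self, device_type):
        self.weight = FakeTensor(2, device_type)


def probe_expert_device(experts, gate_name="gate_proj"):
    """Return the device type of one experts container, or "unknown" (sketch)."""
    gate_up = getattr(experts, "gate_up_proj", None)
    down = getattr(experts, "down_proj", None)
    if isinstance(gate_up, FakeTensor) and isinstance(down, FakeTensor):
        if gate_up.dim() == 3 and down.dim() == 3:
            return gate_up.device.type  # fused v5 layout
    if isinstance(experts, (list, tuple)) and experts:
        gate = getattr(experts[0], gate_name, None)
        if gate is not None:
            return gate.weight.device.type  # per-expert layout
    return "unknown"


fused = SimpleNamespace(gate_up_proj=FakeTensor(3, "cpu"), down_proj=FakeTensor(3, "cpu"))
per_expert = [SimpleNamespace(gate_proj=FakeLinear("cuda"))]
print(probe_expert_device(fused))       # cpu
print(probe_expert_device(per_expert))  # cuda
print(probe_expert_device([]))          # unknown
```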