feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models built on Intel AMX and AVX-512 kernels:

Core C++ implementation:
- sft_moe.hpp: forward/backward passes with fused LoRA operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds the C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (the
  framework-interaction fields); all other logic lives in kt_kernel.sft (see the
  sketch below)
- Inference isolation: importing kt_kernel does not load the sft/ submodule
- Old field-name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B training on 4 GPUs with AMXBF16; loss converges normally.
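
Minimal integration sketch (illustrative only; the field values are placeholders,
`model` is an already-loaded HF MoE model, and the module paths assume the sft/
layout listed above):

    from accelerate.utils.dataclasses import KTransformersPlugin
    from kt_kernel.sft.config import KTConfig
    from kt_kernel.sft.wrapper import wrap_moe_layers_with_kt_wrapper

    kt_config = KTConfig(backend="AMXBF16", num_threads=64, num_gpu_experts=0)
    plugin = KTransformersPlugin(enabled=True, kt_config=kt_config)
    wrappers = wrap_moe_layers_with_kt_wrapper(model, plugin)  # rank 0 loads expert weights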

mrhaoxx, 2026-04-08 23:11 +08:00
commit f36699affd (parent ddb957596f)
84 changed files with 51278 additions and 623 deletions

wrapper.py (kt_kernel.sft model-wrapping entry points), new file:
@@ -0,0 +1,610 @@
# Model wrapping entry points for SFT
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import gc
import importlib.util as _u
import logging
import os
from typing import Any

import torch
import torch.nn as nn

from .arch import (
    KTAMXConfigError,
    KTAMXNotAvailableError,
    MOEArchConfig,
    _get_layers_prefix,
    _get_model_container_and_layers,
    get_moe_arch_config,
    get_moe_module,
)
from .layer import KTMoELayerWrapper
from .lora import LoRAExperts
from .weights import (
    _clear_original_expert_weights,
    extract_moe_weights,
    load_experts_from_checkpoint_files,
)

logger = logging.getLogger(__name__)

KT_KERNEL_AVAILABLE = _u.find_spec("kt_kernel") is not None
if KT_KERNEL_AVAILABLE:
    try:
        from kt_kernel.experts import KTMoEWrapper
    except Exception:
        KTMoEWrapper = None
        KT_KERNEL_AVAILABLE = False
else:
    KTMoEWrapper = None


# =============================================================================
# Device-map builders
# =============================================================================


def _get_kt_config(kt_plugin: Any):
    """Extract KTConfig from a KTransformersPlugin or compatible object.

    Handles three cases:
    1. KTransformersPlugin with .kt_config (new style) -> return kt_config
    2. Object with old field names (kt_num_threads etc.) -> convert to KTConfig
    3. KTConfig directly -> return as-is
    """
    from .config import KTConfig

    # New-style KTransformersPlugin
    kt_config = getattr(kt_plugin, "kt_config", None)
    if kt_config is not None and isinstance(kt_config, KTConfig):
        return kt_config
    # Already a KTConfig
    if isinstance(kt_plugin, KTConfig):
        return kt_plugin
    # Old-style object (HfTrainerKTConfig, old KTransformersPlugin, dict-like) --- convert.
    # Map old field names (kt_xxx) to new field names (xxx).
    _OLD_TO_NEW = {
        "kt_backend": "backend",
        "kt_num_threads": "num_threads",
        "kt_tp_enabled": "tp_enabled",
        "kt_threadpool_count": "threadpool_count",
        "kt_weight_path": "weight_path",
        "kt_expert_checkpoint_path": "expert_checkpoint_path",
        "kt_num_gpu_experts": "num_gpu_experts",
        "kt_max_cache_depth": "max_cache_depth",
        "kt_use_lora_experts": "use_lora_experts",
        "kt_lora_expert_num": "lora_expert_num",
        "kt_lora_expert_intermediate_size": "lora_expert_intermediate_size",
        "kt_skip_expert_loading": "skip_expert_loading",
        "kt_share_backward_bb": "share_backward_bb",
        "kt_checkpoint_files": "checkpoint_files",
        "kt_sharded_metadata": "sharded_metadata",
    }
    kwargs = {}
    for old_name, new_name in _OLD_TO_NEW.items():
        val = getattr(kt_plugin, old_name, None)
        if val is not None:
            kwargs[new_name] = val
    # Fields that don't have kt_ prefix
    for name in ("lora_rank", "lora_alpha", "model_max_length", "wrap_fn", "wrap_kwargs"):
        val = getattr(kt_plugin, name, None)
        if val is not None:
            kwargs[name] = val
    return KTConfig(**kwargs)
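
# Conversion example (sketch): any object exposing the old kt_-prefixed attribute
# names is accepted; the attribute values below are purely illustrative.
#
#   from types import SimpleNamespace
#   old_args = SimpleNamespace(kt_backend="AMXINT8", kt_num_threads=32, lora_rank=8)
#   cfg = _get_kt_config(old_args)
#   # cfg is a KTConfig with backend="AMXINT8", num_threads=32, lora_rank=8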


def build_kt_device_map(config, kt_plugin, device: str = "cuda:0") -> dict[str, str | int]:
    """
    Build device_map for KT model loading with hybrid GPU/CPU expert placement.
    """
    moe_config = get_moe_arch_config(config)
    layers_prefix = _get_layers_prefix(config)
    num_layers = config.num_hidden_layers
    num_experts = moe_config.expert_num
    cfg = _get_kt_config(kt_plugin)
    num_gpu_experts = getattr(cfg, "num_gpu_experts", 0) or 0
    device_map: dict[str, str | int] = {}
    device_map["model.embed_tokens"] = device
    device_map["model.norm"] = device
    device_map["lm_head"] = device
    for layer_idx in range(num_layers):
        layer_prefix = f"{layers_prefix}.{layer_idx}"
        device_map[layer_prefix] = device
        moe_prefix = f"{layer_prefix}.{moe_config.moe_layer_attr}"
        for expert_idx in range(num_experts):
            expert_key = f"{moe_prefix}.{moe_config.experts_attr}.{expert_idx}"
            if expert_idx < num_gpu_experts:
                device_map[expert_key] = device
            else:
                device_map[expert_key] = "cpu"
    logger.info(
        f"Built KT device_map: {num_gpu_experts} GPU experts, {num_experts - num_gpu_experts} CPU experts"
    )
    return device_map
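
# Resulting map sketch for a hypothetical config whose layers prefix is
# "model.layers", moe_layer_attr is "mlp" and experts_attr is "experts", with
# num_gpu_experts=1 (the actual attribute names come from MOEArchConfig):
#
#   {
#       "model.embed_tokens": "cuda:0", "model.norm": "cuda:0", "lm_head": "cuda:0",
#       "model.layers.0": "cuda:0",
#       "model.layers.0.mlp.experts.0": "cuda:0",   # expert_idx < num_gpu_experts
#       "model.layers.0.mlp.experts.1": "cpu",
#       ...
#   }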


def build_kt_device_map_simplified(config, kt_plugin, device: str = "cuda:0") -> dict[str, str | int]:
    """
    Simplified device_map builder: map full layers to GPU, override routed experts to CPU.
    """
    moe_config = get_moe_arch_config(config)
    layers_prefix = _get_layers_prefix(config)
    num_layers = config.num_hidden_layers
    cfg = _get_kt_config(kt_plugin)
    num_gpu_experts = getattr(cfg, "num_gpu_experts", 0) or 0
    device_map: dict[str, str | int] = {}
    device_map["model.embed_tokens"] = device
    device_map["model.norm"] = device
    device_map["lm_head"] = device
    for layer_idx in range(num_layers):
        layer_prefix = f"{layers_prefix}.{layer_idx}"
        device_map[layer_prefix] = device
        experts_prefix = f"{layer_prefix}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}"
        if num_gpu_experts == 0:
            device_map[experts_prefix] = "cpu"
        else:
            return build_kt_device_map(config, kt_plugin, device=device)
    logger.info("Built simplified KT device_map: all layers on GPU, routed experts on CPU")
    return device_map
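
# Simplified variant sketch (num_gpu_experts == 0): each layer maps wholesale to the
# GPU and only its routed-experts module is overridden to CPU, e.g.
# "model.layers.0" -> "cuda:0" and "model.layers.0.mlp.experts" -> "cpu"
# (prefixes again depend on the detected MOEArchConfig). With num_gpu_experts > 0
# the full per-expert builder above is used instead.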


# =============================================================================
# MoE layer wrapping
# =============================================================================


def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KTMoELayerWrapper]:
    """
    Replace model's MoE layers with KTMoEWrapper-based wrappers.

    Loads expert weights into the C++ KT kernel. No LoRA initialization ---
    LoRA is handled by PEFT and later adapted via kt_adapt_peft_lora().
    Only rank 0 initializes KT kernel and loads weights.
    """
    import torch.distributed as dist

    if not KT_KERNEL_AVAILABLE:
        raise KTAMXNotAvailableError("kt_kernel not found. Please install kt_kernel to enable KT MoE support.")
    # Only rank 0 should initialize KT and load weights
    is_rank_0 = True
    if dist.is_initialized():
        is_rank_0 = dist.get_rank() == 0
    moe_config = get_moe_arch_config(model.config)
    hidden_size = model.config.hidden_size
    cfg = _get_kt_config(kt_plugin)
    # Read lora_rank/lora_alpha for C++ wrapper initialization (buffer allocation only)
    lora_rank = getattr(cfg, "lora_rank", 1) or 1
    lora_alpha = getattr(cfg, "lora_alpha", 1.0) or 1.0
    # Read LoRA Experts configuration
    _raw_le = getattr(cfg, "use_lora_experts", None)
    use_lora_experts = bool(_raw_le) if _raw_le is not None else False
    lora_expert_num = getattr(cfg, "lora_expert_num", 2) or 2
    lora_expert_intermediate_size = getattr(cfg, "lora_expert_intermediate_size", 1024) or 1024
    if is_rank_0:
        logger.info(
            f"LoRA Experts config: use_lora_experts={use_lora_experts}, "
            f"num={lora_expert_num}, intermediate_size={lora_expert_intermediate_size}"
        )
    wrappers: list[KTMoELayerWrapper] = []
    moe_layer_count = 0
    kt_backend_map = {
        "AMXBF16": "AMXBF16_SFT",
        "AMXINT8": "AMXINT8_SFT",
        "AMXINT4": "AMXINT4_SFT",
        "AMXBF16_SkipLoRA": "AMXBF16_SFT_SkipLoRA",
        "AMXINT8_SkipLoRA": "AMXINT8_SFT_SkipLoRA",
        "AMXINT4_SkipLoRA": "AMXINT4_SFT_SkipLoRA",
    }
    # Build case-insensitive lookup to handle common typos like "SkipLora" vs "SkipLoRA"
    _kt_backend_map_lower = {k.lower(): v for k, v in kt_backend_map.items()}
    kt_backend = getattr(cfg, "backend", "AMXBF16") or "AMXBF16"
    kt_method = kt_backend_map.get(kt_backend)
    if kt_method is None:
        kt_method = _kt_backend_map_lower.get(kt_backend.lower())
        if kt_method is not None:
            logger.warning(
                f"kt_backend '{kt_backend}' matched via case-insensitive lookup -> '{kt_method}'. "
                f"Please use the exact name from: {list(kt_backend_map.keys())}"
            )
        else:
            kt_method = "AMXBF16_SFT"
            logger.warning(
                f"Unknown kt_backend '{kt_backend}', falling back to '{kt_method}'. "
                f"Valid values: {list(kt_backend_map.keys())}"
            )
    if "SkipLoRA" in kt_method:
        logger.info(f"Using SkipLoRA backend: {kt_method} (MoE LoRA gradients will be skipped)")
    threadpool_count = getattr(cfg, "threadpool_count", 1) if getattr(cfg, "tp_enabled", False) else 1
    kt_weight_path = getattr(cfg, "weight_path", None)
    use_kt_weight_path = kt_weight_path is not None
    if use_kt_weight_path:
        logger.info(f"Loading INT8 weights from kt_weight_path: {kt_weight_path}")
    checkpoint_files = getattr(cfg, "checkpoint_files", None)
    sharded_metadata = getattr(cfg, "sharded_metadata", None)
    # When kt_expert_checkpoint_path is set, always resolve from it (overrides any existing
    # checkpoint_files which may come from AttnOnlyBf16 and lack expert weights).
    kt_expert_checkpoint_path = getattr(cfg, "expert_checkpoint_path", None)
    if kt_expert_checkpoint_path:
        logger.info(f"Resolving expert checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}")
        resolved_files, resolved_meta = _resolve_checkpoint_files(model_name_or_path=kt_expert_checkpoint_path)
        if resolved_files and all(f.endswith(".safetensors") for f in resolved_files):
            checkpoint_files = resolved_files
            sharded_metadata = resolved_meta
            cfg.checkpoint_files = checkpoint_files
            cfg.sharded_metadata = sharded_metadata
            logger.info(f"Resolved {len(checkpoint_files)} checkpoint files from kt_expert_checkpoint_path")
        else:
            logger.warning(f"Failed to resolve checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}")
    use_checkpoint_files = bool(checkpoint_files) and not use_kt_weight_path
    logger.debug(
        f"Weight source: kt_weight_path={kt_weight_path!r}, "
        f"kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}, "
        f"checkpoint_files count={len(checkpoint_files) if checkpoint_files else 0}, "
        f"use_kt_weight_path={use_kt_weight_path}, use_checkpoint_files={use_checkpoint_files}"
    )
    if use_checkpoint_files:
        logger.info("Loading expert weights from checkpoint files (online conversion).")
    elif use_kt_weight_path and bool(checkpoint_files):
        logger.info("BF16 checkpoint files available for backward gradient computation.")
    elif (not use_kt_weight_path) and bool(getattr(cfg, "skip_expert_loading", False)):
        # If HF expert weights were skipped during `from_pretrained`, we must source expert weights externally.
        model_name_or_path = getattr(getattr(model, "config", None), "name_or_path", None)
        if model_name_or_path:
            resolved_files, resolved_meta = _resolve_checkpoint_files(model_name_or_path=model_name_or_path)
            if resolved_files and all(f.endswith(".safetensors") for f in resolved_files):
                checkpoint_files = resolved_files
                sharded_metadata = resolved_meta
                cfg.checkpoint_files = checkpoint_files
                cfg.sharded_metadata = sharded_metadata
                use_checkpoint_files = True
                logger.info("KT skip_expert_loading enabled; using checkpoint files for online expert loading.")
        if not use_checkpoint_files:
            raise KTAMXConfigError(
                "KT skip_expert_loading is enabled but no `kt_weight_path` was provided and no safetensors checkpoint "
                "files could be resolved for on-the-fly expert loading."
            )
    import torch.distributed as _dist

    _rank = _dist.get_rank() if _dist.is_initialized() else 0
    model_container, layers = _get_model_container_and_layers(model, purpose="wrapping")
    logger.info(f"Total layers={len(layers)}, is_rank_0={is_rank_0}")
    for layer_idx, layer in enumerate(layers):
        moe_module = get_moe_module(layer, moe_config)
        if moe_module is None:
            continue
        logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method})")
        # Only rank 0 loads weights and initializes KT kernel
        gate_proj, up_proj, down_proj = None, None, None
        wrapper = None
        if is_rank_0:
            # Get block_size from quantization_config if available (for FP8 dequant)
            _quant_cfg = getattr(model.config, "quantization_config", None)
            _block_size = None
            if _quant_cfg is not None:
                _block_size = getattr(_quant_cfg, "weight_block_size", None)
            if use_kt_weight_path:
                logger.debug(f"Layer {layer_idx}: forward + backward from kt_weight_path (.kt files)")
            elif use_checkpoint_files:
                layers_prefix = _get_layers_prefix(model.config)
                gate_proj, up_proj, down_proj = load_experts_from_checkpoint_files(
                    checkpoint_files=checkpoint_files,
                    sharded_metadata=sharded_metadata,
                    layers_prefix=layers_prefix,
                    moe_config=moe_config,
                    layer_idx=layer_idx,
                    block_size=_block_size,
                )
            else:
                gate_proj, up_proj, down_proj = extract_moe_weights(moe_module, moe_config)
                gate_proj = gate_proj.cpu().to(torch.bfloat16).contiguous()
                up_proj = up_proj.cpu().to(torch.bfloat16).contiguous()
                down_proj = down_proj.cpu().to(torch.bfloat16).contiguous()
        chunked_prefill_size = getattr(cfg, "model_max_length", None)
        if chunked_prefill_size is None:
            chunked_prefill_size = getattr(model.config, "max_position_embeddings", 4096)
        # Only rank 0 creates KTMoEWrapper and loads weights
        if is_rank_0:
            wrapper = KTMoEWrapper(
                layer_idx=layer_idx,
                num_experts=moe_config.expert_num,
                num_experts_per_tok=moe_config.num_experts_per_tok,
                hidden_size=hidden_size,
                moe_intermediate_size=moe_config.intermediate_size,
                num_gpu_experts=0,
                cpuinfer_threads=getattr(cfg, "num_threads", 1),
                threadpool_count=threadpool_count,
                weight_path=kt_weight_path or "",
                chunked_prefill_size=chunked_prefill_size,
                method=kt_method,
                mode="sft",
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                max_cache_depth=getattr(cfg, "max_cache_depth", 2),
            )
            # Set share_backward_bb BEFORE load_weights (config is built during load)
            share_backward_bb = getattr(cfg, "share_backward_bb", None)
            if share_backward_bb is None:
                share_backward_bb = (
                    os.environ.get("ACCELERATE_KT_SHARE_BACKWARD_BB", "").lower() in ("true", "1", "yes")
                )
            wrapper.share_backward_bb = share_backward_bb
            physical_to_logical_map = torch.arange(moe_config.expert_num, dtype=torch.int64, device="cpu")
            if use_kt_weight_path:
                logger.debug(f"Layer {layer_idx}: calling wrapper.load_weights() (C++ direct .kt load)")
                wrapper.load_weights(physical_to_logical_map)
            else:
                logger.debug(
                    f"Layer {layer_idx}: calling wrapper.load_weights_from_tensors() "
                    f"(BF16 tensor path, gate_proj shape={gate_proj.shape if gate_proj is not None else None})"
                )
                wrapper.load_weights_from_tensors(
                    gate_proj=gate_proj,
                    up_proj=up_proj,
                    down_proj=down_proj,
                    physical_to_logical_map_cpu=physical_to_logical_map,
                )
            wrapper.gate_proj = None
            wrapper.up_proj = None
            wrapper.down_proj = None
        # Create LoRA Experts if enabled
        lora_experts = None
        if use_lora_experts:
            lora_experts = LoRAExperts(
                num_experts=lora_expert_num,
                hidden_size=hidden_size,
                intermediate_size=lora_expert_intermediate_size,
                device="cuda",
                dtype=torch.bfloat16,
            )
        layer_wrapper = KTMoELayerWrapper(
            original_moe=moe_module,
            wrapper=wrapper,
            lora_params=None,
            moe_config=moe_config,
            hidden_size=hidden_size,
            layer_idx=layer_idx,
            lora_experts=lora_experts,
        )
        layer_wrapper._skip_lora = "SkipLoRA" in kt_method
        setattr(layer, moe_config.moe_layer_attr, layer_wrapper)
        # Base weights have been copied into the C++ kernel's internal BufferB format.
        # Do not hold a Python-side reference --- it wastes ~1 GB/layer.
        del gate_proj, up_proj, down_proj
        wrappers.append(layer_wrapper)
        moe_layer_count += 1
        # Replace original expert weights with meta placeholders.
        # Experts remain in the model tree (via wrapper.experts) so PEFT can discover them.
        # Rank 0 already copied weights to C++ kernel via load_weights_from_tensors.
        _clear_original_expert_weights(moe_module, moe_config)
    logger.info(f"Wrapped {moe_layer_count} MoE layers with KTMoEWrapper")
    # Link wrappers for async backward repack (higher layer triggers repack for lower)
    for i in range(1, len(wrappers)):
        if wrappers[i].wrapper is not None and wrappers[i - 1].wrapper is not None:
            wrappers[i].wrapper._next_backward_wrapper = wrappers[i - 1].wrapper
    if wrappers and wrappers[0].wrapper is not None:
        wrappers[0].wrapper._next_backward_wrapper = None
    gc.collect()
    return wrappers
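
# Backward repack chain illustration (sketch): for a model with three wrapped MoE
# layers, the linking loop above arranges that backward on a higher layer triggers
# repacking for the layer below it:
#
#   wrappers[2].wrapper._next_backward_wrapper -> wrappers[1].wrapper
#   wrappers[1].wrapper._next_backward_wrapper -> wrappers[0].wrapper
#   wrappers[0].wrapper._next_backward_wrapper -> None   (nothing below the first layer)
#
# How the C++ side consumes _next_backward_wrapper is an implementation detail of
# kt_kernel and is not shown in this file.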


# =============================================================================
# Plugin builder
# =============================================================================


def _build_kt_plugin_from_args(model_args: Any, finetuning_args: Any | None = None):
    """
    Build a KTransformersPlugin from model_args and optional finetuning_args.

    Imported here to avoid circular dependency --- callers that need the plugin
    class should import it from the appropriate dataclasses module.
    """
    from .config import KTConfig
    from accelerate.utils.dataclasses import KTransformersPlugin

    kt_config = KTConfig(
        backend=getattr(model_args, "kt_backend", None),
        num_threads=getattr(model_args, "kt_num_threads", None),
        tp_enabled=getattr(model_args, "kt_tp_enabled", None),
        threadpool_count=getattr(model_args, "kt_threadpool_count", None),
        max_cache_depth=getattr(model_args, "kt_max_cache_depth", None),
        num_gpu_experts=getattr(model_args, "kt_num_gpu_experts", None),
        weight_path=getattr(model_args, "kt_weight_path", None),
        expert_checkpoint_path=getattr(model_args, "kt_expert_checkpoint_path", None),
        use_lora_experts=getattr(model_args, "kt_use_lora_experts", None),
        lora_expert_num=getattr(model_args, "kt_lora_expert_num", None),
        lora_expert_intermediate_size=getattr(model_args, "kt_lora_expert_intermediate_size", None),
        lora_rank=getattr(finetuning_args, "lora_rank", None) if finetuning_args else None,
        lora_alpha=getattr(finetuning_args, "lora_alpha", None) if finetuning_args else None,
        model_max_length=getattr(model_args, "model_max_length", None),
    )
    return KTransformersPlugin(enabled=True, kt_config=kt_config)


def get_kt_loading_kwargs(
    config,
    kt_plugin,
    torch_dtype: torch.dtype | str | None = torch.bfloat16,
    trust_remote_code: bool | None = None,
    token: str | None = None,
) -> dict[str, Any]:
    """Get kwargs for AutoModel.from_pretrained() for KT loading."""
    kwargs: dict[str, Any] = {
        "config": config,
        "torch_dtype": torch_dtype,
        "device_map": "cpu",
        "low_cpu_mem_usage": True,
    }
    if trust_remote_code is not None:
        kwargs["trust_remote_code"] = trust_remote_code
    if token is not None:
        kwargs["token"] = token
    return kwargs


def _resolve_checkpoint_files(
    model_name_or_path: str,
    cache_dir: str | None = None,
    revision: str | None = None,
    token: str | None = None,
    trust_remote_code: bool | None = None,
) -> tuple[list[str] | None, dict | None]:
    """Resolve HF checkpoint files. Depends on transformers internals."""
    try:
        from transformers.modeling_utils import _get_resolved_checkpoint_files
    except Exception:
        return None, None
    try:
        checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
            pretrained_model_name_or_path=model_name_or_path,
            subfolder="",
            variant=None,
            gguf_file=None,
            from_tf=False,
            from_flax=False,
            use_safetensors=None,
            cache_dir=cache_dir,
            force_download=False,
            proxies=None,
            local_files_only=False,
            token=token,
            user_agent={"file_type": "model", "framework": "pytorch"},
            revision=revision or "main",
            commit_hash=None,
            is_remote_code=bool(trust_remote_code),
            transformers_explicit_filename=None,
        )
    except Exception:
        return None, None
    return checkpoint_files, sharded_metadata


def load_kt_model(
    config,
    model_args: Any | None = None,
    finetuning_args: Any | None = None,
    kt_plugin=None,
    model_name_or_path: str | None = None,
    trust_remote_code: bool | None = None,
    token: str | None = None,
    torch_dtype: torch.dtype | str | None = torch.bfloat16,
    **kwargs,
) -> nn.Module:
    """Load model with KTMoEWrapper backend."""
    from .arch import (
        KTAMXConfigError,
        KTAMXNotAvailableError,
        get_expert_device,
        get_moe_arch_config,
        move_non_experts_to_gpu,
    )

    if kt_plugin is None:
        if model_args is None:
            raise KTAMXConfigError("Either kt_plugin or model_args must be provided to load_kt_model().")
        kt_plugin = _build_kt_plugin_from_args(model_args, finetuning_args)
    if model_name_or_path is None and model_args is not None:
        model_name_or_path = getattr(model_args, "model_name_or_path", None)
    if model_name_or_path is None:
        raise KTAMXConfigError("model_name_or_path is required to load_kt_model().")
    if trust_remote_code is None and model_args is not None:
        trust_remote_code = getattr(model_args, "trust_remote_code", None)
    if token is None and model_args is not None:
        token = getattr(model_args, "hf_hub_token", None)
    cache_dir = getattr(model_args, "cache_dir", None) if model_args is not None else None
    revision = getattr(model_args, "revision", None) if model_args is not None else None
    _ = get_moe_arch_config(config)
    logger.info("Loading model with KTMoEWrapper backend")
    from transformers import AutoModelForCausalLM
    from transformers.integrations.kt import set_kt_config, unset_kt_config

    loading_kwargs = get_kt_loading_kwargs(
        config,
        kt_plugin,
        torch_dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
        token=token,
    )
    if model_args is not None:
        for key in ("cache_dir", "revision"):
            value = getattr(model_args, key, None)
            if value is not None:
                loading_kwargs[key] = value
    loading_kwargs.update(kwargs)
    cfg = _get_kt_config(kt_plugin)
    if getattr(cfg, "skip_expert_loading", None) is None:
        checkpoint_files, sharded_metadata = _resolve_checkpoint_files(
            model_name_or_path=model_name_or_path,
            cache_dir=cache_dir,
            revision=revision,
            token=token,
            trust_remote_code=trust_remote_code,
        )
        if checkpoint_files and all(f.endswith(".safetensors") for f in checkpoint_files):
            if getattr(cfg, "weight_path", None) is None:
                cfg.skip_expert_loading = True
            else:
                cfg.skip_expert_loading = False
            cfg.checkpoint_files = checkpoint_files
            cfg.sharded_metadata = sharded_metadata
        else:
            cfg.skip_expert_loading = False
    set_kt_config(kt_plugin)
    try:
        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **loading_kwargs)
    finally:
        unset_kt_config()
    moe_config = get_moe_arch_config(config)
    move_non_experts_to_gpu(model, moe_config, device="cuda:0")
    existing_wrappers = getattr(model, "_kt_wrappers", None)
    if existing_wrappers:
        logger.info(f"MoE layers already wrapped ({len(existing_wrappers)} layers), skipping re-wrap")
        wrappers = existing_wrappers
    else:
        wrappers = wrap_moe_layers_with_kt_wrapper(model, kt_plugin)
    model._kt_wrappers = wrappers
    model._kt_tp_enabled = bool(getattr(cfg, "tp_enabled", False))
    model._kt_use_lora_experts = bool(getattr(cfg, "use_lora_experts", False))
    logger.info("Model loaded with KTMoEWrapper backend successfully")
    return model
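
# Minimal usage sketch (illustrative; the argument objects and their field names
# are whatever the calling trainer defines, and the model id is only an example.
# Only the keyword API of load_kt_model() above is authoritative):
#
#   from transformers import AutoConfig
#
#   config = AutoConfig.from_pretrained("Qwen/Qwen3-235B-A22B")
#   model = load_kt_model(
#       config,
#       model_args=model_args,            # carries kt_backend, kt_num_threads, ...
#       finetuning_args=finetuning_args,  # carries lora_rank / lora_alpha
#   )
#   # After loading: model._kt_wrappers holds the KTMoELayerWrapper list, and
#   # model._kt_tp_enabled / model._kt_use_lora_experts reflect the KTConfig.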