feat(sft): AMX MoE SFT backend with LoRA support (#1936)

* feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models built on Intel AMX:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection
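
The pieces plug together roughly as sketched below (imports mirror the sft/ package exports; the KTConfig values are placeholders and the comments only restate each module's role, not real call signatures):

    from kt_kernel.sft import (
        KTConfig,                          # config.py: opaque config passthrough
        wrap_moe_layers_with_kt_wrapper,   # wrapper.py: swap HF MoE layers for KTMoELayerWrapper
        kt_adapt_peft_lora,                # lora.py: attach KT-managed LoRA view/grad buffers
        get_kt_lora_params,                # lora.py: collect the KT-managed LoRA parameters
    )

    kt_config = KTConfig(kt_backend="AMXBF16", kt_weight_path="/path/to/kt_weights")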

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds the C++ wrapper and expert weights (see the sketch after this list)
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework
  interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically
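
Minimal sketch of the rank-0-only pattern, condensed from KTMoEFunction.forward (autograd.py) and the dist_utils.py helpers in this diff; shapes are flattened to [sum(all_qlens), hidden_size]:

    import torch.distributed as dist
    from kt_kernel.sft.dist_utils import _qlen_offsets, _dist_scatter_varlen_from_rank0

    def rank0_only_moe_forward(wrapper, all_qlens, hidden_size, device, dtype):
        # Only rank 0 holds the C++ wrapper and expert weights; every other rank
        # passes wrapper=None and just receives its slice of the CPU experts' output.
        rank, world_size = dist.get_rank(), dist.get_world_size()
        if rank == 0:
            out = wrapper.sync_forward(output_device=device)       # CPU experts ran on rank 0
            out = out.to(dtype).view(sum(all_qlens), hidden_size)
            offsets = _qlen_offsets(all_qlens)
            chunks = [out[offsets[i]:offsets[i + 1]].contiguous() for i in range(world_size)]
        else:
            chunks = None                                           # wrapper is None on these ranks
        return _dist_scatter_varlen_from_rank0(
            rank0_chunks=chunks, all_qlens=all_qlens, rank=rank, world_size=world_size,
            feature_shape=(hidden_size,), device=device, dtype=dtype,
        )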

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.

* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use the kt_ prefix, matching the dict keys; this eliminates
  the _OLD_TO_NEW mapping and prefix-stripping in wrapper.py (example after this list)
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check
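
Example of the direct construction this enables (paths and values are placeholders):

    from kt_kernel.sft import KTConfig

    cfg_dict = {
        "kt_backend": "AMXBF16",
        "kt_weight_path": "/path/to/kt_weights",   # placeholder
        "kt_lora_rank": 16,
        "kt_lora_alpha": 32.0,
    }
    kt_config = KTConfig(**cfg_dict)   # field names equal dict keys; no prefix stripping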

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline
(1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False,
  auto-set to True by trainer_config_process when gradient checkpointing
  is enabled
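
Sketch of the intended derivation (trainer_config_process lives in the trainer/accelerate integration, not in this diff, so the hook below is only illustrative):

    from kt_kernel.sft import KTConfig

    kt_config = KTConfig()                     # kt_share_backward_bb -> True, kt_share_cache_pool -> False
    if training_args.gradient_checkpointing:   # training_args: the HF TrainingArguments in scope
        kt_config.kt_share_cache_pool = True   # flows through to the C++ cache allocation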

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument,
but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently
and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.
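
A self-contained sketch of the fused-format handling described above; the sizes are made up and the split/init choices are illustrative rather than the exact _create_fused_expert_lora_buffers logic:

    import torch
    import torch.nn as nn

    E, I, H, r = 8, 1408, 2048, 16                                  # experts, intermediate, hidden, LoRA rank

    # v5 fused layout: one 3-D weight per projection instead of per-expert nn.Linear.
    gate_up_proj = torch.zeros(E, 2 * I, H, dtype=torch.bfloat16)   # [E, 2I, H]
    down_proj    = torch.zeros(E, H, I, dtype=torch.bfloat16)       # [E, H, I]
    gate_proj, up_proj = gate_up_proj.chunk(2, dim=1)               # -> two [E, I, H] views

    # KT-managed LoRA buffers: nn.Parameter so the optimizer tracks them,
    # .grad pre-assigned so the C++ backward can write gradients in place.
    def make_lora(out_dim, in_dim):
        a = nn.Parameter(torch.empty(E, r, in_dim, dtype=torch.bfloat16))
        b = nn.Parameter(torch.zeros(E, out_dim, r, dtype=torch.bfloat16))
        nn.init.kaiming_uniform_(a, a=5 ** 0.5)                     # kaiming init per the commit message
        a.grad = torch.zeros_like(a)
        b.grad = torch.zeros_like(b)
        return a, b

    gate_lora_a, gate_lora_b = make_lora(I, H)                      # shapes match init_lora_weights in amx.py
    up_lora_a,   up_lora_b   = make_lora(I, H)
    down_lora_a, down_lora_b = make_lora(H, I)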

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config,
  _get_layers_prefix returns model.language_model.layers for Qwen3.5,
  _get_model_container_and_layers searches language_model attr (condensed sketch after this list)
- weights.py: load_experts_from_checkpoint_files detects fused format
  (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config
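
The arch.py changes boil down to a text_config fallback plus a layers-prefix switch; a condensed sketch:

    def resolve_qwen_moe_config(config):
        # Qwen3.5 nests the MoE fields under config.text_config and keeps its
        # decoder stack at model.language_model.layers; older Qwen MoE models
        # expose both at the top level.
        arch = config.architectures[0] if getattr(config, "architectures", None) else ""
        cfg = getattr(config, "text_config", config)
        layers_prefix = "model.language_model.layers" if "Qwen3_5Moe" in arch else "model.layers"
        return cfg.num_experts, cfg.moe_intermediate_size, layers_prefix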

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed the Python-side method
calls, but the backend wrapper (AMXSFTMoEWrapper) still exposes the original
method names. This caused an AttributeError on Qwen3.5-35B and other models.

* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace,
  sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp,
  moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
mrhaoxx 2026-04-22 11:27:01 +08:00 committed by GitHub
parent 22e9915ec9
commit 9544a8960d
41 changed files with 22866 additions and 937 deletions


@@ -0,0 +1,83 @@
# SFT (Supervised Fine-Tuning) submodule for kt-kernel
# SPDX-License-Identifier: Apache-2.0
"""
SFT training support for KT-Kernel MoE.
This submodule adds training capabilities (forward/backward, LoRA, autograd,
distributed) on top of the inference-only kt_kernel base package.
Additional dependencies beyond base kt_kernel: torch.nn, torch.distributed, peft (optional).
"""
from .config import KTConfig
from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer
from .amx import AMXSFTMoEWrapper
from .arch import (
MOEArchConfig, get_moe_arch_config, get_moe_module, move_non_experts_to_gpu, get_expert_device,
KTAMXError, KTAMXNotAvailableError, KTAMXModelNotSupportedError, KTAMXConfigError,
)
from .autograd import KTMoEFunction
from .layer import KTMoELayerWrapper
from .weights import (
extract_moe_weights,
load_experts_from_checkpoint_files,
load_experts_from_kt_weight_path,
INT8ExpertWeights,
)
from .lora import (
kt_adapt_peft_lora,
get_kt_lora_params,
update_kt_lora_pointers,
sync_kt_lora_gradients,
save_lora_experts_to_adapter,
save_kt_moe_to_adapter,
load_lora_experts_from_adapter,
load_kt_moe_from_adapter,
LoRAExpertMLP,
LoRAExperts,
)
from .wrapper import (
wrap_moe_layers_with_kt_wrapper,
build_kt_device_map,
build_kt_device_map_simplified,
get_kt_loading_kwargs,
load_kt_model,
)
__all__ = [
"KTConfig",
"BaseSFTMoEWrapper",
"KExpertsSFTBuffer",
"AMXSFTMoEWrapper",
"MOEArchConfig",
"get_moe_arch_config",
"get_moe_module",
"move_non_experts_to_gpu",
"get_expert_device",
"KTAMXError",
"KTAMXNotAvailableError",
"KTAMXModelNotSupportedError",
"KTAMXConfigError",
"KTMoEFunction",
"KTMoELayerWrapper",
"extract_moe_weights",
"load_experts_from_checkpoint_files",
"load_experts_from_kt_weight_path",
"INT8ExpertWeights",
"kt_adapt_peft_lora",
"get_kt_lora_params",
"update_kt_lora_pointers",
"sync_kt_lora_gradients",
"save_lora_experts_to_adapter",
"save_kt_moe_to_adapter",
"load_lora_experts_from_adapter",
"load_kt_moe_from_adapter",
"LoRAExpertMLP",
"LoRAExperts",
"wrap_moe_layers_with_kt_wrapper",
"build_kt_device_map",
"build_kt_device_map_simplified",
"get_kt_loading_kwargs",
"load_kt_model",
]

kt-kernel/python/sft/amx.py (new file, 434 lines)

@@ -0,0 +1,434 @@
# AMX SFT MoE Wrapper implementation
# SPDX-License-Identifier: Apache-2.0
"""
AMX-based SFT MoE Wrapper. Forward/backward buffer management lives in the base class;
this file handles weight loading, LoRA init, and C++ task construction.
"""
from __future__ import annotations
import ctypes
import os
import glob as _glob
import torch
from typing import Optional, List
from kt_kernel_ext.moe import MOESFTConfig
from ..utils.loader import BF16SafeTensorLoader, SafeTensorLoader
try:
from kt_kernel_ext.moe import (
AMXBF16_SFT_MOE,
AMXInt8_SFT_MOE,
AMXInt4_SFT_MOE,
AMXBF16_SFT_MOE_SkipLoRA,
AMXInt8_SFT_MOE_SkipLoRA,
AMXInt4_SFT_MOE_SkipLoRA,
)
_HAS_AMX_SFT_SUPPORT = True
except (ImportError, AttributeError):
_HAS_AMX_SFT_SUPPORT = False
AMXBF16_SFT_MOE = None
AMXInt8_SFT_MOE = None
AMXInt4_SFT_MOE = None
AMXBF16_SFT_MOE_SkipLoRA = None
AMXInt8_SFT_MOE_SkipLoRA = None
AMXInt4_SFT_MOE_SkipLoRA = None
from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer
# Mapping from method string to C++ SFT MOE class
_SFT_METHOD_TO_CLASS = {
"AMXBF16_SFT": AMXBF16_SFT_MOE,
"AMXINT8_SFT": AMXInt8_SFT_MOE,
"AMXINT4_SFT": AMXInt4_SFT_MOE,
"AMXBF16_SFT_SkipLoRA": AMXBF16_SFT_MOE_SkipLoRA,
"AMXINT8_SFT_SkipLoRA": AMXInt8_SFT_MOE_SkipLoRA,
"AMXINT4_SFT_SkipLoRA": AMXInt4_SFT_MOE_SkipLoRA,
}
class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
"""
AMX-based SFT MoE wrapper.
Supports BF16, INT8, INT4, and SkipLoRA variants.
Forward/backward buffer management is in BaseSFTMoEWrapper;
this class implements weight loading and C++ task construction.
"""
def __init__(
self,
layer_idx: int,
num_experts: int,
num_experts_per_tok: int,
hidden_size: int,
moe_intermediate_size: int,
num_gpu_experts: int,
cpuinfer_threads: int,
threadpool_count: int,
weight_path: str,
chunked_prefill_size: int,
lora_rank: int = 16,
lora_alpha: float = 32.0,
max_cache_depth: int = 1,
method: str = "AMXBF16_SFT",
group_size: int = 128,
zero_point: bool = True,
):
if not _HAS_AMX_SFT_SUPPORT:
raise RuntimeError(
"AMX SFT backend not available. kt_kernel_ext was not compiled with AMX SFT support.\n"
"Please recompile with AMX SFT enabled."
)
super().__init__(
layer_idx=layer_idx,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
num_gpu_experts=num_gpu_experts,
cpuinfer_threads=cpuinfer_threads,
threadpool_count=threadpool_count,
weight_path=weight_path,
chunked_prefill_size=chunked_prefill_size,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
max_cache_depth=max_cache_depth,
)
self.method = method
self._is_skip_lora = "SkipLoRA" in method
self.group_size = group_size
self.zero_point = zero_point
if method not in _SFT_METHOD_TO_CLASS:
raise ValueError(f"Unknown SFT method: {method}. Supported: {list(_SFT_METHOD_TO_CLASS.keys())}")
moe_class = _SFT_METHOD_TO_CLASS[method]
if moe_class is None:
raise RuntimeError(f"AMX SFT method '{method}' not available in current build.")
self.gate_proj: Optional[torch.Tensor] = None
self.up_proj: Optional[torch.Tensor] = None
self.down_proj: Optional[torch.Tensor] = None
self._moe_class = moe_class
# ========== Template method: C++ task construction ==========
def _make_forward_task(self, buffer: KExpertsSFTBuffer, save_for_backward: bool):
return self.moe.forward_sft_task(
buffer.bsz_tensor.data_ptr(),
self.num_experts_per_tok,
buffer.expert_ids_cpu.data_ptr(),
buffer.weights_cpu.data_ptr(),
buffer.input_cpu.data_ptr(),
buffer.output_cpu.data_ptr(),
save_for_backward,
)
def _make_backward_task(self, buffer: KExpertsSFTBuffer):
if self._is_skip_lora:
return self.moe.backward_task(
buffer.grad_output_cpu.data_ptr(),
buffer.grad_input_cpu.data_ptr(),
0, 0, 0, 0, 0, 0,
buffer.grad_weights.data_ptr(),
)
return self.moe.backward_task(
buffer.grad_output_cpu.data_ptr(),
buffer.grad_input_cpu.data_ptr(),
self.grad_gate_lora_a.data_ptr(),
self.grad_gate_lora_b.data_ptr(),
self.grad_up_lora_a.data_ptr(),
self.grad_up_lora_b.data_ptr(),
self.grad_down_lora_a.data_ptr(),
self.grad_down_lora_b.data_ptr(),
buffer.grad_weights.data_ptr(),
)
# ========== Weight loading ==========
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None:
if self._weights_loaded:
return
if self.gate_proj is None and not getattr(self, "_use_projs_path", False):
self._load_base_weights_from_file()
config = MOESFTConfig()
config.expert_num = self.num_experts
config.num_experts_per_tok = self.num_experts_per_tok
config.hidden_size = self.hidden_size
config.intermediate_size = self.moe_intermediate_size
config.lora_rank = self.lora_rank
config.lora_alpha = self.lora_alpha
config.max_cache_depth = self.max_cache_depth
config.max_len = self.chunked_prefill_size
config.layer_idx = self.layer_idx
config.share_backward_bb = getattr(self, "share_backward_bb", False)
config.share_cache_pool = getattr(self, "share_cache_pool", False)
if getattr(self, "_use_kt_direct_load", False):
config.load = True
config.path = self.weight_path
elif getattr(self, "_use_projs_path", False):
config.gate_projs = self._gate_projs_ptrs
config.up_projs = self._up_projs_ptrs
config.down_projs = self._down_projs_ptrs
config.gate_scales = self._gate_scale_ptrs
config.up_scales = self._up_scale_ptrs
config.down_scales = self._down_scale_ptrs
if getattr(self, "_bf16_gate_proj", None) is not None:
config.gate_proj = self._bf16_gate_proj.data_ptr()
config.up_proj = self._bf16_up_proj.data_ptr()
config.down_proj = self._bf16_down_proj.data_ptr()
if getattr(self, "_has_bwd_projs", False):
config.gate_bwd_projs = self._gate_bwd_projs_ptrs
config.up_bwd_projs = self._up_bwd_projs_ptrs
config.down_bwd_projs = self._down_bwd_projs_ptrs
config.gate_bwd_scales = self._gate_bwd_scale_ptrs
config.up_bwd_scales = self._up_bwd_scale_ptrs
config.down_bwd_scales = self._down_bwd_scale_ptrs
else:
config.gate_proj = self.gate_proj.data_ptr()
config.up_proj = self.up_proj.data_ptr()
config.down_proj = self.down_proj.data_ptr()
if self._lora_initialized:
config.gate_lora_a = self.gate_lora_a.data_ptr()
config.gate_lora_b = self.gate_lora_b.data_ptr()
config.up_lora_a = self.up_lora_a.data_ptr()
config.up_lora_b = self.up_lora_b.data_ptr()
config.down_lora_a = self.down_lora_a.data_ptr()
config.down_lora_b = self.down_lora_b.data_ptr()
config.pool = self.cpu_infer.backend_
if self.method in ("AMXINT4_KGroup_SFT", "AMXINT4_1KGroup_SFT"):
config.quant_config.group_size = self.group_size
config.quant_config.zero_point = self.zero_point
self.moe = self._moe_class(config)
self.cpu_infer.submit(self.moe.load_weights_task())
self.cpu_infer.sync()
self.cpu_infer.submit(self.moe.warm_up_task())
self.cpu_infer.sync()
# Release Python-side weight tensors (C++ copied them)
self.gate_proj = None
self.up_proj = None
self.down_proj = None
if getattr(self, "_bf16_gate_proj", None) is not None:
self._bf16_gate_proj = None
self._bf16_up_proj = None
self._bf16_down_proj = None
if getattr(self, "_use_projs_path", False):
for attr in [
"_gate_weights_per_numa", "_up_weights_per_numa", "_down_weights_per_numa",
"_gate_scales_per_numa", "_up_scales_per_numa", "_down_scales_per_numa",
"_gate_projs_ptrs", "_up_projs_ptrs", "_down_projs_ptrs",
"_gate_scale_ptrs", "_up_scale_ptrs", "_down_scale_ptrs",
]:
setattr(self, attr, None)
if getattr(self, "_has_bwd_projs", False):
for attr in [
"_gate_bwd_weights_per_numa", "_up_bwd_weights_per_numa", "_down_bwd_weights_per_numa",
"_gate_bwd_scales_per_numa", "_up_bwd_scales_per_numa", "_down_bwd_scales_per_numa",
"_gate_bwd_projs_ptrs", "_up_bwd_projs_ptrs", "_down_bwd_projs_ptrs",
"_gate_bwd_scale_ptrs", "_up_bwd_scale_ptrs", "_down_bwd_scale_ptrs",
]:
setattr(self, attr, None)
self._weights_loaded = True
def load_weights_from_tensors(
self,
gate_proj: torch.Tensor,
up_proj: torch.Tensor,
down_proj: torch.Tensor,
physical_to_logical_map_cpu: torch.Tensor,
) -> None:
self.gate_proj = gate_proj.contiguous()
self.up_proj = up_proj.contiguous()
self.down_proj = down_proj.contiguous()
self.load_weights(physical_to_logical_map_cpu)
del gate_proj, up_proj, down_proj
def _load_base_weights_from_file(self) -> None:
if not hasattr(self, "weight_path") or self.weight_path is None:
raise RuntimeError(
"weight_path not set. Cannot load weights from file. "
"Either set weight_path or call load_weights_from_tensors() instead."
)
kt_layer_dir = os.path.join(self.weight_path, f"_layer_{self.layer_idx}")
if os.path.isdir(kt_layer_dir):
kt_files = _glob.glob(os.path.join(kt_layer_dir, "_numa_0", "*.kt"))
if kt_files:
self._use_kt_direct_load = True
return
if "BF16" in self.method:
loader = BF16SafeTensorLoader(self.weight_path)
base_key = f"model.layers.{self.layer_idx}"
else:
loader = SafeTensorLoader(self.weight_path)
base_key = f"blk.{self.layer_idx}"
experts_data = loader.load_experts(base_key, device="cpu")
gate_weights: List[torch.Tensor] = experts_data["gate"]
up_weights: List[torch.Tensor] = experts_data["up"]
down_weights: List[torch.Tensor] = experts_data["down"]
if "BF16" in self.method:
self.gate_proj = torch.stack(gate_weights, dim=0).contiguous()
self.up_proj = torch.stack(up_weights, dim=0).contiguous()
self.down_proj = torch.stack(down_weights, dim=0).contiguous()
else:
def _make_ptrs(arrays_per_numa):
return [
[
ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
for et in numa_array
]
for numa_array in arrays_per_numa
]
self._gate_weights_per_numa = gate_weights
self._up_weights_per_numa = up_weights
self._down_weights_per_numa = down_weights
self._gate_scales_per_numa = experts_data["gate_scale"]
self._up_scales_per_numa = experts_data["up_scale"]
self._down_scales_per_numa = experts_data["down_scale"]
self._gate_projs_ptrs = _make_ptrs(gate_weights)
self._up_projs_ptrs = _make_ptrs(up_weights)
self._down_projs_ptrs = _make_ptrs(down_weights)
self._gate_scale_ptrs = _make_ptrs(experts_data["gate_scale"])
self._up_scale_ptrs = _make_ptrs(experts_data["up_scale"])
self._down_scale_ptrs = _make_ptrs(experts_data["down_scale"])
if "gate_bwd" in experts_data:
self._gate_bwd_weights_per_numa = experts_data["gate_bwd"]
self._up_bwd_weights_per_numa = experts_data["up_bwd"]
self._down_bwd_weights_per_numa = experts_data["down_bwd"]
self._gate_bwd_scales_per_numa = experts_data["gate_bwd_scale"]
self._up_bwd_scales_per_numa = experts_data["up_bwd_scale"]
self._down_bwd_scales_per_numa = experts_data["down_bwd_scale"]
self._gate_bwd_projs_ptrs = _make_ptrs(experts_data["gate_bwd"])
self._up_bwd_projs_ptrs = _make_ptrs(experts_data["up_bwd"])
self._down_bwd_projs_ptrs = _make_ptrs(experts_data["down_bwd"])
self._gate_bwd_scale_ptrs = _make_ptrs(experts_data["gate_bwd_scale"])
self._up_bwd_scale_ptrs = _make_ptrs(experts_data["up_bwd_scale"])
self._down_bwd_scale_ptrs = _make_ptrs(experts_data["down_bwd_scale"])
self._has_bwd_projs = True
else:
self._has_bwd_projs = False
self.gate_proj = None
self.up_proj = None
self.down_proj = None
self._use_projs_path = True
loader.close_all_handles()
# ========== LoRA ==========
def init_lora_weights(
self,
gate_lora_a: torch.Tensor, gate_lora_b: torch.Tensor,
up_lora_a: torch.Tensor, up_lora_b: torch.Tensor,
down_lora_a: torch.Tensor, down_lora_b: torch.Tensor,
grad_gate_lora_a: torch.Tensor, grad_gate_lora_b: torch.Tensor,
grad_up_lora_a: torch.Tensor, grad_up_lora_b: torch.Tensor,
grad_down_lora_a: torch.Tensor, grad_down_lora_b: torch.Tensor,
) -> None:
expected_shapes = {
"gate_lora_a": (self.num_experts, self.lora_rank, self.hidden_size),
"gate_lora_b": (self.num_experts, self.moe_intermediate_size, self.lora_rank),
"up_lora_a": (self.num_experts, self.lora_rank, self.hidden_size),
"up_lora_b": (self.num_experts, self.moe_intermediate_size, self.lora_rank),
"down_lora_a": (self.num_experts, self.lora_rank, self.moe_intermediate_size),
"down_lora_b": (self.num_experts, self.hidden_size, self.lora_rank),
}
provided = {
"gate_lora_a": gate_lora_a, "gate_lora_b": gate_lora_b,
"up_lora_a": up_lora_a, "up_lora_b": up_lora_b,
"down_lora_a": down_lora_a, "down_lora_b": down_lora_b,
}
for name, tensor in provided.items():
expected = expected_shapes[name]
if tensor.shape != expected:
raise ValueError(f"{name} shape mismatch: expected {expected}, got {tuple(tensor.shape)}")
self.gate_lora_a = gate_lora_a.contiguous()
self.gate_lora_b = gate_lora_b.contiguous()
self.up_lora_a = up_lora_a.contiguous()
self.up_lora_b = up_lora_b.contiguous()
self.down_lora_a = down_lora_a.contiguous()
self.down_lora_b = down_lora_b.contiguous()
self.grad_gate_lora_a = grad_gate_lora_a.contiguous()
self.grad_gate_lora_b = grad_gate_lora_b.contiguous()
self.grad_up_lora_a = grad_up_lora_a.contiguous()
self.grad_up_lora_b = grad_up_lora_b.contiguous()
self.grad_down_lora_a = grad_down_lora_a.contiguous()
self.grad_down_lora_b = grad_down_lora_b.contiguous()
self._lora_initialized = True
if self._weights_loaded and self.moe is not None:
self.update_lora_weights()
def update_lora_weights(self) -> None:
if not self._weights_loaded:
raise RuntimeError("Weights not loaded. Call load_weights() first.")
if self._is_skip_lora:
return
if not self._lora_initialized:
raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.")
self.cpu_infer.submit(
self.moe.update_lora_weights_task(
self.gate_lora_a.data_ptr(),
self.gate_lora_b.data_ptr(),
self.up_lora_a.data_ptr(),
self.up_lora_b.data_ptr(),
self.down_lora_a.data_ptr(),
self.down_lora_b.data_ptr(),
)
)
self.cpu_infer.sync()
def save_backward_weights_from_tensors(
self,
gate_proj: torch.Tensor,
up_proj: torch.Tensor,
down_proj: torch.Tensor,
physical_to_logical_map: torch.Tensor,
output_path: str,
) -> None:
if not self._weights_loaded:
raise RuntimeError("Weights not loaded. Call load_weights() first.")
gate_proj = gate_proj.contiguous()
up_proj = up_proj.contiguous()
down_proj = down_proj.contiguous()
self.moe.prepare_and_save_bwd(
gate_proj.data_ptr(),
up_proj.data_ptr(),
down_proj.data_ptr(),
output_path,
)


@@ -0,0 +1,282 @@
# MoE architecture configuration and model utilities
# SPDX-License-Identifier: Apache-2.0
"""
MoE architecture detection and model navigation utilities.
This is a leaf module: no imports from other sft/ submodules.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
# =============================================================================
# Exceptions
# =============================================================================
class KTAMXError(Exception):
"""Base exception for KT AMX errors."""
class KTAMXNotAvailableError(KTAMXError):
"""kt_kernel not installed or AMX not supported."""
class KTAMXModelNotSupportedError(KTAMXError):
"""Model architecture not supported."""
class KTAMXConfigError(KTAMXError):
"""Configuration error."""
# =============================================================================
# MoE Configuration
# =============================================================================
@dataclass
class MOEArchConfig:
"""MoE architecture configuration for different model types."""
moe_layer_attr: str
router_attr: str
experts_attr: str
weight_names: tuple[str, str, str]
expert_num: int
intermediate_size: int
num_experts_per_tok: int
has_shared_experts: bool = False
router_type: str = "linear"
def get_moe_arch_config(config) -> MOEArchConfig:
"""
Get MoE architecture configuration based on model type.
Args:
config: HuggingFace model configuration
Returns:
MOEArchConfig for the model
Raises:
KTAMXModelNotSupportedError: If model architecture is not supported
"""
arch = config.architectures[0] if getattr(config, "architectures", None) else ""
if "DeepseekV2" in arch:
return MOEArchConfig(
moe_layer_attr="mlp",
router_attr="gate",
experts_attr="experts",
weight_names=("gate_proj", "up_proj", "down_proj"),
expert_num=config.n_routed_experts,
intermediate_size=config.moe_intermediate_size,
num_experts_per_tok=config.num_experts_per_tok,
has_shared_experts=getattr(config, "n_shared_experts", 0) > 0,
router_type="deepseek_gate",
)
if "DeepseekV3" in arch:
return MOEArchConfig(
moe_layer_attr="mlp",
router_attr="gate",
experts_attr="experts",
weight_names=("gate_proj", "up_proj", "down_proj"),
expert_num=config.n_routed_experts,
intermediate_size=config.moe_intermediate_size,
num_experts_per_tok=config.num_experts_per_tok,
has_shared_experts=getattr(config, "n_shared_experts", 0) > 0,
router_type="deepseek_gate",
)
if "Qwen2Moe" in arch or "Qwen3Moe" in arch or "Qwen3_5Moe" in arch:
cfg = getattr(config, "text_config", config)
return MOEArchConfig(
moe_layer_attr="mlp",
router_attr="gate",
experts_attr="experts",
weight_names=("gate_proj", "up_proj", "down_proj"),
expert_num=cfg.num_experts,
intermediate_size=cfg.moe_intermediate_size,
num_experts_per_tok=cfg.num_experts_per_tok,
has_shared_experts=getattr(cfg, "shared_expert_intermediate_size", 0) > 0,
)
if "Mixtral" in arch:
return MOEArchConfig(
moe_layer_attr="block_sparse_moe",
router_attr="gate",
experts_attr="experts",
weight_names=("w1", "w3", "w2"),
expert_num=config.num_local_experts,
intermediate_size=config.intermediate_size,
num_experts_per_tok=config.num_experts_per_tok,
has_shared_experts=False,
)
raise KTAMXModelNotSupportedError(
f"Model architecture {arch} not supported for KT AMX. "
"Supported architectures: DeepseekV2, DeepseekV3, Qwen2Moe, Qwen3Moe, Qwen3_5Moe, Mixtral"
)
def get_moe_module(layer: nn.Module, moe_config: MOEArchConfig) -> nn.Module | None:
"""Get MoE module from transformer layer."""
moe_module = getattr(layer, moe_config.moe_layer_attr, None)
if moe_module is None:
return None
if not hasattr(moe_module, moe_config.experts_attr):
return None
return moe_module
def detect_fused_experts(experts: nn.Module) -> bool:
"""Detect if experts module uses the transformers v5 fused format.
Fused format: a single Module with ``gate_up_proj`` [E, 2I, H] and
``down_proj`` [E, H, I] 3-D tensors instead of a ModuleList of Linear experts.
"""
if experts is None:
return False
gate_up = getattr(experts, "gate_up_proj", None)
down = getattr(experts, "down_proj", None)
if isinstance(gate_up, torch.Tensor) and isinstance(down, torch.Tensor):
return gate_up.dim() == 3 and down.dim() == 3
return False
def _get_layers_prefix(config) -> str:
arch = config.architectures[0] if getattr(config, "architectures", None) else ""
if "Qwen3_5Moe" in arch:
return "model.language_model.layers"
return "model.layers"
def _get_model_container_and_layers(model: nn.Module, *, purpose: str) -> tuple[nn.Module, any]:
"""
Resolve the transformer layer container for KT integration.
KT expects the transformer block stack to be accessible as `<container>.layers`.
Handles PEFT PeftModel, TRL value-head models, DDP wrappers.
"""
to_visit: list[nn.Module] = [model]
visited: set[int] = set()
visited_types: list[str] = []
while to_visit:
current = to_visit.pop(0)
if id(current) in visited:
continue
visited.add(id(current))
visited_types.append(type(current).__name__)
layers = getattr(current, "layers", None)
if layers is not None and isinstance(layers, (list, tuple, nn.ModuleList)):
return current, layers
for attr in ("model", "base_model", "pretrained_model", "module", "language_model"):
child = getattr(current, attr, None)
if isinstance(child, nn.Module) and child is not current:
to_visit.append(child)
get_base_model = getattr(current, "get_base_model", None)
if callable(get_base_model):
try:
base = get_base_model()
except Exception:
base = None
if isinstance(base, nn.Module) and base is not current:
to_visit.append(base)
visited_preview = ", ".join(visited_types[:6])
if len(visited_types) > 6:
visited_preview += ", ..."
raise KTAMXConfigError(
f"Model does not expose a .model.layers or .layers attribute for KT {purpose}. "
"Tried unwrapping via model/base_model/pretrained_model/module/get_base_model; "
f"visited: {visited_preview}"
)
def move_non_experts_to_gpu(
model: nn.Module,
moe_config: MOEArchConfig | None = None,
device: str = "cuda:0",
) -> None:
"""Move non-expert parameters to GPU after loading (experts stay on CPU)."""
if moe_config is None:
config = getattr(model, "config", None)
if config is None:
raise KTAMXConfigError("Model config is required to infer MoE architecture.")
moe_config = get_moe_arch_config(config)
container, layers = _get_model_container_and_layers(model, purpose="placement")
if hasattr(container, "embed_tokens"):
container.embed_tokens.to(device)
if hasattr(container, "norm"):
container.norm.to(device)
if hasattr(model, "lm_head"):
model.lm_head.to(device)
for layer in layers:
if hasattr(layer, "self_attn"):
layer.self_attn.to(device)
if hasattr(layer, "input_layernorm"):
layer.input_layernorm.to(device)
if hasattr(layer, "post_attention_layernorm"):
layer.post_attention_layernorm.to(device)
moe_module = getattr(layer, moe_config.moe_layer_attr, None)
if moe_module is None or not hasattr(moe_module, moe_config.experts_attr):
if hasattr(layer, "mlp"):
layer.mlp.to(device)
continue
router = getattr(moe_module, moe_config.router_attr, None)
if router is not None:
router.to(device)
if hasattr(moe_module, "shared_experts") and moe_module.shared_experts is not None:
moe_module.shared_experts.to(device)
logger.info(f"Moved non-expert parameters to {device}")
def get_expert_device(model: nn.Module, moe_config: MOEArchConfig | None = None) -> str:
"""Get the device type of MoE experts."""
if moe_config is None:
config = getattr(model, "config", None)
if config is None:
return "unknown"
moe_config = get_moe_arch_config(config)
try:
_, layers = _get_model_container_and_layers(model, purpose="expert device probing")
except KTAMXConfigError:
return "unknown"
for layer in layers:
moe_module = getattr(layer, moe_config.moe_layer_attr, None)
if moe_module is None:
continue
experts = getattr(moe_module, moe_config.experts_attr, None)
if not experts:
continue
first_expert = experts[0]
gate_name = moe_config.weight_names[0]
gate_proj = getattr(first_expert, gate_name, None)
if gate_proj is not None:
return str(gate_proj.weight.device.type)
return "unknown"


@@ -0,0 +1,254 @@
# Autograd function for KT MoE SFT training
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import os
from typing import Any
import torch
from .dist_utils import (
_all_gather_qlens,
_qlen_offsets,
_dist_gather_varlen_to_rank0,
_dist_scatter_varlen_from_rank0,
)
_KT_SFT_DEBUG = os.environ.get("KT_SFT_DEBUG", "0") == "1"
logger = logging.getLogger(__name__)
class KTMoEFunction(torch.autograd.Function):
"""Unified autograd function for KTMoE forward/backward."""
@staticmethod
def forward(
ctx,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
wrapper: Any,
lora_ref: torch.Tensor,
hidden_size: int,
num_experts_per_tok: int,
layer_idx: int,
training: bool,
train_lora: bool,
all_qlens: list[int] | tuple[int, ...] | None,
) -> torch.Tensor:
if _KT_SFT_DEBUG:
logging.debug(
"KTMoEFunction.forward: layer=%d training=%s train_lora=%s",
layer_idx, training, train_lora,
)
original_device = hidden_states.device
original_dtype = hidden_states.dtype
batch_size, seq_len, _ = hidden_states.shape
qlen = batch_size * seq_len
import torch.distributed as dist
dist_on = dist.is_initialized() and dist.get_world_size() > 1
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist_on else 1
ctx.use_broadcast = wrapper is None
# ---- Sync CPU expert result and distribute ----
if dist_on:
if all_qlens is None:
all_qlens_list = _all_gather_qlens(qlen, original_device, world_size)
else:
all_qlens_list = [int(q) for q in all_qlens]
if len(all_qlens_list) != world_size:
raise RuntimeError(
f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}"
)
if int(all_qlens_list[rank]) != qlen:
raise RuntimeError(
f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}"
)
total_qlen = sum(all_qlens_list)
# Rank 0: sync CPU result and split by real lengths
if rank == 0:
cpu_output = wrapper.sync_forward(output_device=original_device)
cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, hidden_size)
offsets = _qlen_offsets(all_qlens_list)
scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
else:
scatter_list = None
output_flat = _dist_scatter_varlen_from_rank0(
rank0_chunks=scatter_list,
all_qlens=all_qlens_list,
rank=rank,
world_size=world_size,
feature_shape=(hidden_size,),
device=original_device,
dtype=original_dtype,
)
output = output_flat.view(batch_size, seq_len, hidden_size)
del output_flat
elif wrapper is not None:
# Single-GPU: sync directly
cpu_output = wrapper.sync_forward(output_device=original_device)
output = cpu_output.view(batch_size, seq_len, hidden_size).to(dtype=original_dtype)
else:
# Broadcast-only rank (no wrapper)
output = torch.empty(
batch_size, seq_len, hidden_size, device=original_device, dtype=original_dtype
)
ctx.wrapper = wrapper
ctx.hidden_size = hidden_size
ctx.qlen = qlen
ctx.batch_size = batch_size
ctx.seq_len = seq_len
ctx.original_device = original_device
ctx.original_dtype = original_dtype
ctx.weights_shape = topk_weights.shape
ctx.weights_dtype = topk_weights.dtype
ctx.weights_device = topk_weights.device
ctx.dist_on = dist_on
ctx.world_size = world_size
ctx.all_qlens = all_qlens_list if dist_on else None
ctx.num_experts_per_tok = num_experts_per_tok
ctx.layer_idx = layer_idx
# Save a sentinel tensor so non-reentrant checkpoint's saved_tensors
# hooks can intercept it. When backward accesses ctx.saved_tensors,
# the checkpoint unpack hook triggers a full recompute of the decoder
# layer — which re-runs the MoE forward with save_for_backward=True,
# populating the C++ cache BEFORE this backward proceeds.
# Without this, MoE backward runs before the recompute (MoE comes
# after attention in forward order → its backward runs first), and
# the C++ cache is empty when first-forward cache-skip is active.
ctx.save_for_backward(hidden_states.new_empty(()))
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# Wait for any in-flight async repack before recompute forward uses the pool
if getattr(ctx.wrapper, 'share_backward_bb', False):
ctx.wrapper.wait_backward_repack()
# Access saved_tensors FIRST — under non-reentrant checkpoint this
# triggers the unpack hook which runs a full decoder-layer recompute,
# populating the C++ cache before we call wrapper.backward().
_ = ctx.saved_tensors
qlen = ctx.qlen
hidden_size = ctx.hidden_size
batch_size = ctx.batch_size
seq_len = ctx.seq_len
dist_on = ctx.dist_on
world_size = ctx.world_size
num_experts_per_tok = ctx.num_experts_per_tok
import torch.distributed as dist
rank = dist.get_rank() if dist.is_initialized() else 0
if _KT_SFT_DEBUG:
logging.debug(
"KTMoEFunction.backward: layer=%d dist_on=%s qlen=%d",
getattr(ctx, "layer_idx", -1), dist_on, qlen,
)
if dist_on:
all_qlens = getattr(ctx, "all_qlens", None)
if all_qlens is None or len(all_qlens) != world_size:
all_qlens = _all_gather_qlens(qlen, ctx.original_device, world_size)
else:
all_qlens = [int(q) for q in all_qlens]
if int(all_qlens[rank]) != qlen:
raise RuntimeError(
f"Backward qlen mismatch on rank {rank}: local={qlen}, all_qlens[{rank}]={all_qlens[rank]}"
)
grad_out_flat = grad_output.view(qlen, hidden_size).contiguous()
gathered_go = _dist_gather_varlen_to_rank0(
grad_out_flat,
all_qlens=all_qlens,
rank=rank,
world_size=world_size,
)
if rank == 0:
all_go = torch.cat(gathered_go, dim=0)
total_qlen = int(all_go.shape[0])
backward_out = ctx.wrapper.backward(
all_go,
output_device=ctx.original_device,
)
if isinstance(backward_out, tuple) and len(backward_out) == 2:
all_grad_input, all_grad_weights = backward_out
elif isinstance(backward_out, tuple) and len(backward_out) == 3:
all_grad_input, _, all_grad_weights = backward_out
else:
raise ValueError("KTMoEWrapper.backward returned unexpected format.")
all_grad_input = all_grad_input.to(dtype=ctx.original_dtype).view(total_qlen, hidden_size)
all_grad_weights = all_grad_weights.to(dtype=torch.bfloat16).view(total_qlen, num_experts_per_tok)
offsets = _qlen_offsets(all_qlens)
scatter_gi = [all_grad_input[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
scatter_gw = [all_grad_weights[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
else:
scatter_gi = None
scatter_gw = None
grad_input_flat = _dist_scatter_varlen_from_rank0(
rank0_chunks=scatter_gi,
all_qlens=all_qlens,
rank=rank,
world_size=world_size,
feature_shape=(hidden_size,),
device=ctx.original_device,
dtype=ctx.original_dtype,
)
grad_weights_flat = _dist_scatter_varlen_from_rank0(
rank0_chunks=scatter_gw,
all_qlens=all_qlens,
rank=rank,
world_size=world_size,
feature_shape=(num_experts_per_tok,),
device=ctx.weights_device,
dtype=torch.bfloat16,
)
grad_input = grad_input_flat.view(batch_size, seq_len, hidden_size)
grad_weights = grad_weights_flat.view(ctx.weights_shape).to(dtype=ctx.weights_dtype)
elif not ctx.use_broadcast:
# ---- Single-GPU path ----
grad_output_flat = grad_output.view(qlen, hidden_size)
backward_out = ctx.wrapper.backward(
grad_output_flat,
output_device=ctx.original_device,
)
ctx.wrapper._kt_has_cached_forward = False
if isinstance(backward_out, tuple) and len(backward_out) == 2:
grad_input, grad_weights = backward_out
elif isinstance(backward_out, tuple) and len(backward_out) == 3:
grad_input, _, grad_weights = backward_out
else:
raise ValueError("KTMoEWrapper.backward returned unexpected format.")
grad_input = grad_input.view(batch_size, seq_len, hidden_size).to(dtype=ctx.original_dtype)
grad_weights = grad_weights.to(dtype=torch.bfloat16)
else:
# No wrapper, no dist — shouldn't happen in normal flow
grad_input = torch.zeros(batch_size, seq_len, hidden_size, device=ctx.original_device, dtype=ctx.original_dtype)
grad_weights = torch.zeros(ctx.weights_shape, device=ctx.weights_device, dtype=ctx.weights_dtype)
# Trigger async repack for next MoE layer in backward order
next_bwd = getattr(ctx.wrapper, '_next_backward_wrapper', None)
if next_bwd is not None and getattr(next_bwd, 'share_backward_bb', False):
next_bwd.submit_backward_repack()
return grad_input, None, grad_weights, None, None, None, None, None, None, None, None


@@ -0,0 +1,402 @@
# Base classes for SFT MoE operations
# SPDX-License-Identifier: Apache-2.0
"""
SFT (Supervised Fine-Tuning) MoE base classes and buffer management.
Provides:
- KExpertsSFTBuffer: Grow-only shared buffer for forward/backward passes
- BaseSFTMoEWrapper: Abstract base with concrete buffer management (template method pattern)
"""
from __future__ import annotations
import torch
from typing import Optional, Tuple
from abc import ABC, abstractmethod
from ..experts_base import _MoEBase
class KExpertsSFTBuffer:
"""
CPU buffer management for SFT expert computation.
Single grow-only buffer (never shrinks). Callers must use [:qlen] slicing
since the buffer may be larger than the current batch.
"""
_shared_buffer: Optional["KExpertsSFTBuffer"] = None
def __init__(
self,
qlen: int,
hidden_size: int,
moe_intermediate_size: int,
num_experts: int,
num_experts_per_tok: int,
lora_rank: int,
dtype: torch.dtype = torch.bfloat16,
):
self.qlen = qlen
self.hidden_size = hidden_size
self.moe_intermediate_size = moe_intermediate_size
self.num_experts = num_experts
self.num_experts_per_tok = num_experts_per_tok
self.lora_rank = lora_rank
self.dtype = dtype
pin_memory = False
# Forward buffers
self.input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
self.expert_ids_cpu = torch.empty(
(qlen, num_experts_per_tok), dtype=torch.int64, device="cpu", pin_memory=pin_memory
)
self.weights_cpu = torch.empty(
(qlen, num_experts_per_tok), dtype=torch.float32, device="cpu", pin_memory=pin_memory
)
self.output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
# Backward buffers
self.grad_output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
self.grad_input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
self.grad_weights = torch.empty((qlen, num_experts_per_tok), dtype=torch.float32, device="cpu")
# Batch size tensor for C++ interface
self.bsz_tensor = torch.tensor([qlen], dtype=torch.int32, device="cpu")
@classmethod
def get_buffer(
cls,
qlen: int,
hidden_size: int,
moe_intermediate_size: int,
num_experts: int,
num_experts_per_tok: int,
lora_rank: int,
dtype: torch.dtype = torch.bfloat16,
) -> "KExpertsSFTBuffer":
"""Get or grow the single shared buffer. Only reallocates when qlen exceeds capacity."""
buf = cls._shared_buffer
if buf is not None and qlen <= buf.qlen:
return buf
cls._shared_buffer = cls(
qlen=qlen,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
lora_rank=lora_rank,
dtype=dtype,
)
return cls._shared_buffer
@classmethod
def clear_cache(cls) -> None:
"""Clear the shared buffer."""
cls._shared_buffer = None
class BaseSFTMoEWrapper(_MoEBase, ABC):
"""
Base class for SFT MoE CPU operations with concrete buffer management.
Subclasses implement:
- _make_forward_task(buffer, save_for_backward) -> C++ task object
- _make_backward_task(buffer) -> C++ task object
- load_weights(physical_to_logical_map_cpu)
- init_lora_weights(...)
- update_lora_weights()
"""
def __init__(
self,
layer_idx: int,
num_experts: int,
num_experts_per_tok: int,
hidden_size: int,
moe_intermediate_size: int,
num_gpu_experts: int,
cpuinfer_threads: int,
threadpool_count: int,
weight_path: str,
chunked_prefill_size: int,
lora_rank: int = 16,
lora_alpha: float = 32.0,
max_cache_depth: int = 1,
):
self.cpu_infer = self._get_cpu_infer(cpuinfer_threads, threadpool_count)
self._validate_base_config(
num_experts=num_experts,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
num_experts_per_tok=num_experts_per_tok,
)
self._validate_sft_config(lora_rank, lora_alpha, max_cache_depth)
self.layer_idx = layer_idx
self.num_experts = num_experts
self.num_experts_per_tok = num_experts_per_tok
self.hidden_size = hidden_size
self.moe_intermediate_size = moe_intermediate_size
self.num_gpu_experts = num_gpu_experts
self.weight_path = weight_path
self.chunked_prefill_size = chunked_prefill_size
self.threadpool_count = threadpool_count
self.lora_rank = lora_rank
self.lora_alpha = lora_alpha
self.lora_scaling = lora_alpha / lora_rank
self.max_cache_depth = max_cache_depth
self.gate_lora_a: Optional[torch.Tensor] = None
self.gate_lora_b: Optional[torch.Tensor] = None
self.up_lora_a: Optional[torch.Tensor] = None
self.up_lora_b: Optional[torch.Tensor] = None
self.down_lora_a: Optional[torch.Tensor] = None
self.down_lora_b: Optional[torch.Tensor] = None
self._weights_loaded: bool = False
self._lora_initialized: bool = False
self._cache_depth: int = 0
self._is_skip_lora: bool = False
self.moe = None
@staticmethod
def _validate_sft_config(lora_rank: int, lora_alpha: float, max_cache_depth: int) -> None:
if lora_rank <= 0:
raise ValueError(f"lora_rank must be positive, got {lora_rank}")
if lora_alpha <= 0:
raise ValueError(f"lora_alpha must be positive, got {lora_alpha}")
if max_cache_depth <= 0:
raise ValueError(f"max_cache_depth must be positive, got {max_cache_depth}")
# ========== Abstract methods for subclasses ==========
@abstractmethod
def _make_forward_task(self, buffer: KExpertsSFTBuffer, save_for_backward: bool):
"""Construct the C++ forward task object. Backend-specific."""
...
@abstractmethod
def _make_backward_task(self, buffer: KExpertsSFTBuffer):
"""Construct the C++ backward task object. Backend-specific."""
...
@abstractmethod
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None:
...
@abstractmethod
def init_lora_weights(
self,
gate_lora_a: torch.Tensor, gate_lora_b: torch.Tensor,
up_lora_a: torch.Tensor, up_lora_b: torch.Tensor,
down_lora_a: torch.Tensor, down_lora_b: torch.Tensor,
grad_gate_lora_a: torch.Tensor, grad_gate_lora_b: torch.Tensor,
grad_up_lora_a: torch.Tensor, grad_up_lora_b: torch.Tensor,
grad_down_lora_a: torch.Tensor, grad_down_lora_b: torch.Tensor,
) -> None:
...
@abstractmethod
def update_lora_weights(self) -> None:
...
# ========== Buffer helpers ==========
def _get_buffer(self, qlen: int) -> KExpertsSFTBuffer:
return KExpertsSFTBuffer.get_buffer(
qlen=qlen,
hidden_size=self.hidden_size,
moe_intermediate_size=self.moe_intermediate_size,
num_experts=self.num_experts,
num_experts_per_tok=self.num_experts_per_tok,
lora_rank=self.lora_rank,
dtype=torch.bfloat16,
)
def _validate_forward_inputs(self, hidden_states: torch.Tensor, expert_ids: torch.Tensor, weights: torch.Tensor):
if not self._weights_loaded:
raise RuntimeError("Weights not loaded. Call load_weights() or load_weights_from_tensors() first.")
if not self._lora_initialized and not self._is_skip_lora:
raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.")
qlen = hidden_states.shape[0]
if qlen > self.chunked_prefill_size:
raise ValueError(
f"qlen ({qlen}) exceeds chunked_prefill_size ({self.chunked_prefill_size}). "
"Increase chunked_prefill_size or reduce qlen to avoid buffer overrun."
)
if expert_ids.shape[0] != qlen or expert_ids.shape[1] != self.num_experts_per_tok:
raise ValueError(
f"expert_ids shape {tuple(expert_ids.shape)} must be ({qlen}, {self.num_experts_per_tok})."
)
if weights.shape[0] != qlen or weights.shape[1] != self.num_experts_per_tok:
raise ValueError(
f"weights shape {tuple(weights.shape)} must be ({qlen}, {self.num_experts_per_tok})."
)
def _copy_inputs_to_buffer(self, buffer: KExpertsSFTBuffer, hidden_states: torch.Tensor,
expert_ids: torch.Tensor, weights: torch.Tensor, qlen: int) -> torch.device:
"""Copy inputs to CPU buffer, return input device."""
input_device = hidden_states.device
buffer.input_cpu[:qlen].copy_(hidden_states.to(torch.bfloat16), non_blocking=True)
buffer.expert_ids_cpu[:qlen].copy_(expert_ids.to(torch.int64), non_blocking=True)
buffer.weights_cpu[:qlen].copy_(weights.to(torch.float32), non_blocking=True)
buffer.bsz_tensor[0] = qlen
if input_device.type == "cuda":
torch.cuda.synchronize(input_device)
return input_device
def _copy_grad_output_to_cpu(self, buffer: KExpertsSFTBuffer, grad_output: torch.Tensor, qlen: int):
"""Copy grad_output to CPU buffer."""
input_device = grad_output.device
if input_device.type == "cuda":
torch.cuda.synchronize(input_device)
buffer.grad_output_cpu[:qlen].copy_(grad_output.to(torch.bfloat16))
def _return_output(self, buffer: KExpertsSFTBuffer, qlen: int, output_device: Optional[torch.device]):
if output_device is not None:
return buffer.output_cpu[:qlen].to(device=output_device, non_blocking=True)
else:
return buffer.output_cpu[:qlen].clone()
def _return_grads(self, buffer: KExpertsSFTBuffer, qlen: int, output_device: Optional[torch.device]):
if output_device is not None:
grad_input = buffer.grad_input_cpu[:qlen].to(device=output_device, non_blocking=True)
grad_weights = buffer.grad_weights[:qlen].to(device=output_device, non_blocking=True)
else:
grad_input = buffer.grad_input_cpu[:qlen].clone()
grad_weights = buffer.grad_weights[:qlen].clone()
return grad_input, grad_weights
# ========== Concrete forward/backward ==========
def forward(
self,
hidden_states: torch.Tensor,
expert_ids: torch.Tensor,
weights: torch.Tensor,
save_for_backward: bool = True,
output_device: Optional[torch.device] = None,
) -> torch.Tensor:
"""Synchronous forward pass with optional gradient caching."""
self._validate_forward_inputs(hidden_states, expert_ids, weights)
qlen = hidden_states.shape[0]
buffer = self._get_buffer(qlen)
self._copy_inputs_to_buffer(buffer, hidden_states, expert_ids, weights, qlen)
self.cpu_infer.submit(self._make_forward_task(buffer, save_for_backward))
self.cpu_infer.sync()
if save_for_backward and self._cache_depth == 0:
self._cache_depth += 1
return self._return_output(buffer, qlen, output_device)
def backward(
self,
grad_output: torch.Tensor,
output_device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Backward pass computing grad_input and grad_weights."""
if self._cache_depth <= 0:
raise RuntimeError("No forward cache available. Call forward(save_for_backward=True) first.")
qlen = grad_output.shape[0]
buffer = self._get_buffer(qlen)
self._copy_grad_output_to_cpu(buffer, grad_output, qlen)
self.cpu_infer.submit(self._make_backward_task(buffer))
self.cpu_infer.sync()
self._cache_depth -= 1
return self._return_grads(buffer, qlen, output_device)
# ========== Async forward ==========
def submit_forward(
self,
hidden_states: torch.Tensor,
expert_ids: torch.Tensor,
weights: torch.Tensor,
save_for_backward: bool = True,
) -> None:
"""Submit forward pass asynchronously (non-blocking). Call sync_forward() to get results."""
self._validate_forward_inputs(hidden_states, expert_ids, weights)
qlen = hidden_states.shape[0]
buffer = self._get_buffer(qlen)
self._copy_inputs_to_buffer(buffer, hidden_states, expert_ids, weights, qlen)
self._pending_buffer = buffer
self._pending_save_for_backward = save_for_backward
self._pending_qlen = qlen
self.cpu_infer.submit(self._make_forward_task(buffer, save_for_backward))
def sync_forward(self, output_device: Optional[torch.device] = None) -> torch.Tensor:
"""Synchronize and retrieve forward results. Must be called after submit_forward()."""
if not hasattr(self, "_pending_buffer") or self._pending_buffer is None:
raise RuntimeError("No pending forward. Call submit_forward() first.")
self.cpu_infer.sync()
buffer = self._pending_buffer
save_for_backward = self._pending_save_for_backward
qlen = self._pending_qlen
if save_for_backward and self._cache_depth == 0:
self._cache_depth += 1
self._pending_buffer = None
self._pending_save_for_backward = None
self._pending_qlen = None
return self._return_output(buffer, qlen, output_device)
# ========== Async backward ==========
def submit_backward_async(
self,
grad_output: torch.Tensor,
output_device: Optional[torch.device] = None,
) -> None:
"""Submit backward task without waiting. Call sync_backward() for results."""
if self._cache_depth <= 0:
raise RuntimeError("No forward cache available. Call forward(save_for_backward=True) first.")
qlen = grad_output.shape[0]
buffer = self._get_buffer(qlen)
self._copy_grad_output_to_cpu(buffer, grad_output, qlen)
self.cpu_infer.submit(self._make_backward_task(buffer))
self._async_bwd_qlen = qlen
self._async_bwd_output_device = output_device
def sync_backward(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Wait for async backward and return results."""
self.cpu_infer.sync()
qlen = self._async_bwd_qlen
output_device = self._async_bwd_output_device
buffer = self._get_buffer(qlen)
self._cache_depth -= 1
return self._return_grads(buffer, qlen, output_device)
# ========== Backward repack (optional, subclasses may override) ==========
def submit_backward_repack(self):
if not self._weights_loaded or self.moe is None:
return
if hasattr(self.moe, 'submit_backward_repack'):
self.moe.submit_backward_repack()
def wait_backward_repack(self):
if not self._weights_loaded or self.moe is None:
return
if hasattr(self.moe, 'wait_backward_repack'):
self.moe.wait_backward_repack()


@@ -0,0 +1,139 @@
# KT-Kernel SFT configuration
# SPDX-License-Identifier: Apache-2.0
"""
KTConfig: kt-kernel's own configuration dataclass.
This is the kt-kernel equivalent of DeepSpeed's JSON config —
it holds all kt-kernel-specific settings and is passed through
KTransformersPlugin.kt_config (similar to DeepSpeedPlugin.hf_ds_config).
"""
from __future__ import annotations
import dataclasses
import os
from dataclasses import dataclass, field
from typing import Any, Callable
def _env_int(key: str, default: int | None) -> int | None:
value = os.environ.get(key, None)
if value is None or value == "":
return default
return int(value)
def _env_float(key: str, default: float | None) -> float | None:
value = os.environ.get(key, None)
if value is None or value == "":
return default
return float(value)
def _env_bool(key: str, default: bool) -> bool:
value = os.environ.get(key, None)
if value is None or value == "":
return default
return value.lower() in ("1", "true", "yes")
@dataclass
class KTConfig:
"""
KT-Kernel configuration for SFT training.
All field names use the ``kt_`` prefix so they match the dict keys used in
HfTrainerKTConfig / YAML configs. This means ``KTConfig(**dict)`` works
directly; no name-mapping or prefix-stripping needed.
Can be created from:
- Direct construction: KTConfig(kt_backend="AMXBF16", kt_weight_path="/path/...")
- Dict: KTConfig(**config_dict)
- Environment variables: KTConfig() reads ACCELERATE_KT_* env vars as defaults
"""
# Backend selection
kt_backend: str | None = None
kt_num_threads: int | None = None
kt_tp_enabled: bool | None = None
kt_threadpool_count: int | None = None
# Weight loading
kt_weight_path: str | None = None
kt_expert_checkpoint_path: str | None = None
kt_num_gpu_experts: int | None = None
kt_skip_expert_loading: bool | None = None
kt_share_backward_bb: bool | None = None # default True — always saves memory
kt_share_cache_pool: bool | None = None # auto-set by trainer_config_process, not user-facing
# Cache
kt_max_cache_depth: int | None = None
kt_model_max_length: int | None = None
# LoRA
kt_lora_rank: int | None = None
kt_lora_alpha: float | None = None
# LoRA Experts (GPU-side extra experts)
kt_use_lora_experts: bool | None = None
kt_lora_expert_num: int | None = None
kt_lora_expert_intermediate_size: int | None = None
# Runtime state (set during wrapping, not by user)
kt_checkpoint_files: list[str] | None = None
kt_sharded_metadata: dict | None = None
# Custom wrapping
kt_wrap_fn: Callable[..., Any] | None = None
kt_wrap_kwargs: dict[str, Any] | None = None
@classmethod
def from_object(cls, obj: Any) -> "KTConfig":
"""Create KTConfig from an attribute-based object (HfTrainerKTConfig, etc.)."""
_field_names = {f.name for f in dataclasses.fields(cls)}
kwargs: dict[str, Any] = {}
for name in _field_names:
val = getattr(obj, name, None)
if val is not None:
kwargs[name] = val
return cls(**kwargs)
def __post_init__(self):
if self.kt_backend is None:
self.kt_backend = os.environ.get("ACCELERATE_KT_BACKEND", "AMXBF16")
if self.kt_num_threads is None:
self.kt_num_threads = _env_int("ACCELERATE_KT_NUM_THREADS", 1)
if self.kt_tp_enabled is None:
self.kt_tp_enabled = _env_bool("ACCELERATE_KT_TP_ENABLED", False)
if self.kt_threadpool_count is None:
self.kt_threadpool_count = _env_int("ACCELERATE_KT_THREADPOOL_COUNT", 1)
if self.kt_weight_path is None:
self.kt_weight_path = os.environ.get("ACCELERATE_KT_WEIGHT_PATH", None)
if self.kt_expert_checkpoint_path is None:
self.kt_expert_checkpoint_path = os.environ.get("ACCELERATE_KT_EXPERT_CHECKPOINT_PATH", None)
if self.kt_num_gpu_experts is None:
self.kt_num_gpu_experts = _env_int("ACCELERATE_KT_NUM_GPU_EXPERTS", 0)
if self.kt_max_cache_depth is None:
self.kt_max_cache_depth = _env_int("ACCELERATE_KT_MAX_CACHE_DEPTH", 2)
if self.kt_share_backward_bb is None:
self.kt_share_backward_bb = _env_bool("ACCELERATE_KT_SHARE_BACKWARD_BB", True)
if self.kt_share_cache_pool is None:
self.kt_share_cache_pool = False
if self.kt_use_lora_experts is None:
self.kt_use_lora_experts = _env_bool("ACCELERATE_KT_USE_LORA_EXPERTS", False)
if self.kt_lora_expert_num is None:
self.kt_lora_expert_num = _env_int("ACCELERATE_KT_LORA_EXPERT_NUM", None)
if self.kt_lora_expert_intermediate_size is None:
self.kt_lora_expert_intermediate_size = _env_int("ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE", None)
if self.kt_lora_rank is None:
self.kt_lora_rank = _env_int("ACCELERATE_KT_LORA_RANK", None)
if self.kt_lora_alpha is None:
self.kt_lora_alpha = _env_float("ACCELERATE_KT_LORA_ALPHA", None)
if self.kt_lora_alpha is None and self.kt_lora_rank is not None:
self.kt_lora_alpha = float(self.kt_lora_rank * 2)
if self.kt_model_max_length is None:
self.kt_model_max_length = _env_int("ACCELERATE_KT_MODEL_MAX_LENGTH", None)
if self.kt_skip_expert_loading is None:
if "ACCELERATE_KT_SKIP_EXPERT_LOADING" in os.environ:
self.kt_skip_expert_loading = _env_bool("ACCELERATE_KT_SKIP_EXPERT_LOADING", True)
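# Illustrative usage sketch (not part of the original file); it exercises the
# three construction paths described in the KTConfig docstring. Paths and
# values below are placeholders.
if __name__ == "__main__":
    # 1. Direct construction
    cfg = KTConfig(kt_backend="AMXBF16", kt_weight_path="/path/to/kt_weights")
    # 2. From a plain dict: keys already carry the kt_ prefix, so ** just works.
    cfg_from_dict = KTConfig(**{"kt_backend": "AMXBF16", "kt_lora_rank": 16})
    print(cfg_from_dict.kt_lora_alpha)  # 32.0 (2 * rank) if ACCELERATE_KT_LORA_ALPHA is unset
    # 3. From ACCELERATE_KT_* environment variables
    os.environ["ACCELERATE_KT_NUM_THREADS"] = "32"
    assert KTConfig().kt_num_threads == 32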


@@ -0,0 +1,171 @@
# Distributed and checkpoint utilities for SFT
# SPDX-License-Identifier: Apache-2.0
"""
Shared distributed communication and gradient-checkpoint detection helpers.
This is a leaf module: no imports from other sft/ submodules.
"""
from __future__ import annotations
from contextlib import nullcontext
from typing import Any
import torch
def _all_gather_qlens(local_qlen: int, device: torch.device, world_size: int) -> list[int]:
import torch.distributed as dist
local_qlen_t = torch.tensor([int(local_qlen)], device=device, dtype=torch.int64)
gathered = [torch.empty(1, device=device, dtype=torch.int64) for _ in range(world_size)]
dist.all_gather(gathered, local_qlen_t)
return [int(t.item()) for t in gathered]
def _qlen_offsets(all_qlens: list[int]) -> list[int]:
offsets = [0]
for q in all_qlens:
offsets.append(offsets[-1] + int(q))
return offsets
def _dist_gather_varlen_to_rank0(
local_tensor: torch.Tensor,
*,
all_qlens: list[int],
rank: int,
world_size: int,
) -> list[torch.Tensor] | None:
import torch.distributed as dist
local_tensor = local_tensor.contiguous()
local_expected = int(all_qlens[rank])
if local_tensor.shape[0] != local_expected:
raise RuntimeError(
f"Local leading dim mismatch on rank {rank}: got {local_tensor.shape[0]}, expected {local_expected}"
)
if rank == 0:
gathered: list[torch.Tensor | None] = [None] * world_size
gathered[0] = local_tensor
ops: list[dist.P2POp] = []
for src in range(1, world_size):
qlen_src = int(all_qlens[src])
recv_shape = (qlen_src, *local_tensor.shape[1:])
recv = torch.empty(recv_shape, device=local_tensor.device, dtype=local_tensor.dtype)
gathered[src] = recv
if qlen_src > 0:
ops.append(dist.P2POp(dist.irecv, recv, src))
if ops:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
out: list[torch.Tensor] = []
for idx, t in enumerate(gathered):
if t is None:
raise RuntimeError(f"Missing gathered tensor for rank {idx} on rank0.")
out.append(t)
return out
if local_expected > 0:
reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, local_tensor, 0)])
for req in reqs:
req.wait()
return None
def _dist_scatter_varlen_from_rank0(
*,
rank0_chunks: list[torch.Tensor] | None,
all_qlens: list[int],
rank: int,
world_size: int,
feature_shape: tuple[int, ...],
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
import torch.distributed as dist
local_qlen = int(all_qlens[rank])
local_out = torch.empty((local_qlen, *feature_shape), device=device, dtype=dtype)
if rank == 0:
if rank0_chunks is None or len(rank0_chunks) != world_size:
raise RuntimeError("rank0_chunks must contain one chunk per rank on rank0.")
if int(rank0_chunks[0].shape[0]) != local_qlen:
raise RuntimeError(
f"Rank0 local chunk mismatch: got {rank0_chunks[0].shape[0]}, expected {local_qlen}"
)
if local_qlen > 0:
local_out.copy_(rank0_chunks[0])
ops: list[dist.P2POp] = []
for dst in range(1, world_size):
qlen_dst = int(all_qlens[dst])
if qlen_dst <= 0:
continue
chunk = rank0_chunks[dst].contiguous()
if int(chunk.shape[0]) != qlen_dst:
raise RuntimeError(
f"Rank{dst} chunk mismatch on rank0: got {chunk.shape[0]}, expected {qlen_dst}"
)
ops.append(dist.P2POp(dist.isend, chunk, dst))
if ops:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
return local_out
if local_qlen > 0:
reqs = dist.batch_isend_irecv([dist.P2POp(dist.irecv, local_out, 0)])
for req in reqs:
req.wait()
return local_out
def _checkpoint_hook_mode() -> str:
"""Infer checkpoint phase from current saved_tensors_hooks top.
Returns one of:
- "first_forward": non-reentrant checkpoint's _checkpoint_hook
- "recompute": non-reentrant checkpoint's _recomputation_hook
- "none": no default saved_tensors_hooks on top
- "other": unknown hook stack entry
- "error": failed to query hook stack
"""
try:
top = torch._C._autograd._top_saved_tensors_default_hooks(False)
except Exception:
return "error"
if top is None:
return "none"
try:
pack_fn, _ = top
mod = getattr(pack_fn, "__module__", "")
qual = getattr(pack_fn, "__qualname__", getattr(pack_fn, "__name__", ""))
tag = f"{mod}.{qual}"
except Exception:
return "other"
if "_recomputation_hook.__init__.<locals>.pack_hook" in tag:
return "recompute"
if "_checkpoint_hook.__init__.<locals>.pack_hook" in tag:
return "first_forward"
return "other"
def _maybe_zero3_gathered_parameters(params: list[torch.nn.Parameter]):
if not params:
return nullcontext()
try:
from transformers.integrations import is_deepspeed_zero3_enabled
except Exception:
return nullcontext()
if not is_deepspeed_zero3_enabled():
return nullcontext()
try:
import deepspeed # type: ignore
except Exception:
return nullcontext()
return deepspeed.zero.GatheredParameters(params, modifier_rank=0)
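# Illustrative round-trip sketch (not part of the original file): gather
# variable-length rows to rank 0, regroup them, and scatter per-rank chunks
# back. It assumes an already-initialized process group whose backend supports
# point-to-point ops (e.g. gloo on CPU); shapes are toy values.
def _example_gather_scatter_roundtrip(hidden_size: int = 4) -> torch.Tensor:
    import torch.distributed as dist

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device("cpu")
    # Rank r contributes (r + 1) rows filled with the value r.
    local = torch.full((rank + 1, hidden_size), float(rank), device=device)
    all_qlens = _all_gather_qlens(local.shape[0], device, world_size)
    gathered = _dist_gather_varlen_to_rank0(
        local, all_qlens=all_qlens, rank=rank, world_size=world_size
    )
    chunks = None
    if rank == 0:
        merged = torch.cat(gathered, dim=0)  # [sum(all_qlens), hidden_size]
        offsets = _qlen_offsets(all_qlens)
        chunks = [merged[offsets[i]: offsets[i + 1]] for i in range(world_size)]
    # Every rank gets its own rows back (still filled with its rank value).
    return _dist_scatter_varlen_from_rank0(
        rank0_chunks=chunks,
        all_qlens=all_qlens,
        rank=rank,
        world_size=world_size,
        feature_shape=(hidden_size,),
        device=device,
        dtype=local.dtype,
    )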


@@ -0,0 +1,399 @@
# KTMoELayerWrapper — nn.Module replacing HF MoE layers for SFT
# SPDX-License-Identifier: Apache-2.0
"""
KTMoELayerWrapper: drop-in nn.Module replacement for HuggingFace MoE layers.
Delegates expert computation to the C++ KTMoEWrapper backend, with support
for gradient checkpointing, PEFT LoRA on experts, LoRA Experts (separate
small MLPs on GPU), shared experts, and multi-GPU rank-0-only execution.
"""
from __future__ import annotations
import logging
import os
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from .arch import MOEArchConfig
from .autograd import KTMoEFunction
from .dist_utils import (
_all_gather_qlens,
_checkpoint_hook_mode,
_dist_gather_varlen_to_rank0,
_dist_scatter_varlen_from_rank0,
_qlen_offsets,
)
logger = logging.getLogger(__name__)
_KT_SFT_DEBUG = os.environ.get("KT_SFT_DEBUG", "0") == "1"
class KTMoELayerWrapper(nn.Module):
"""Wrapper for MoE layer using KTMoEWrapper."""
def __init__(
self,
original_moe: nn.Module,
wrapper: Any,
lora_params: dict[str, nn.Parameter] | None, # Kept for backward compatibility, but ignored
moe_config: MOEArchConfig,
hidden_size: int,
layer_idx: int,
lora_experts: "LoRAExperts | None" = None,
):
super().__init__()
self._is_kt_moe_wrapper = True
self.wrapper = wrapper
self.moe_config = moe_config
self.hidden_size = hidden_size
self.layer_idx = layer_idx
self.router_type = moe_config.router_type
# IMPORTANT: Register submodules in the SAME ORDER as original MoE module
# so that PEFT's named_modules() traversal order matches baseline.
# This ensures kaiming_uniform_ calls happen in the same sequence.
# Qwen3MoeSparseMoeBlock order: gate FIRST, then experts.
# 1. gate/router FIRST - keep original attribute name for PEFT compatibility
router_attr = moe_config.router_attr # "gate" for Qwen3/DeepSeek
setattr(self, router_attr, getattr(original_moe, router_attr, None))
self._router_attr = router_attr
# 2. experts SECOND (this is what PEFT targets for LoRA)
experts_attr = moe_config.experts_attr # typically "experts"
setattr(self, experts_attr, getattr(original_moe, experts_attr, None))
self._experts_attr = experts_attr
# 3. shared_experts (if any)
if moe_config.has_shared_experts and hasattr(original_moe, "shared_experts"):
self.shared_experts = original_moe.shared_experts
else:
self.shared_experts = None
# 4. lora_experts (separate LoRA expert MLPs, different from PEFT LoRA on experts)
self.lora_experts = lora_experts
# PEFT LoRA tracking (set by kt_adapt_peft_lora)
# _peft_lora_modules: {expert_idx: {proj_name: (lora_A, lora_B)}}
self._peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]] | None = None
self._lora_pointers_dirty = False
def _apply(self, fn, recurse=True):
# Protect experts from device transfer (PEFT LoRA should stay on CPU for KT)
saved_experts = None
experts_attr = getattr(self, '_experts_attr', None)
if experts_attr is not None and getattr(self, experts_attr, None) is not None:
saved_experts = getattr(self, experts_attr)
self._modules.pop(experts_attr, None)
result = super()._apply(fn, recurse)
if saved_experts is not None:
self._modules[experts_attr] = saved_experts
return result
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
import torch.distributed as dist
dist_on = dist.is_initialized() and dist.get_world_size() > 1
rank = dist.get_rank() if dist.is_initialized() else 0
# Check if we need to use distributed broadcast (only rank 0 has KT kernel)
use_broadcast = dist_on and self.wrapper is None
topk_ids, topk_weights = self._compute_routing(hidden_states)
train_lora = self._peft_lora_modules is not None and len(self._peft_lora_modules) > 0
save_for_backward = (
self.training
and torch.is_grad_enabled()
and (hidden_states.requires_grad or topk_weights.requires_grad or train_lora)
)
use_autograd_path = save_for_backward
save_for_backward_submit = use_autograd_path
if _checkpoint_hook_mode() == "first_forward":
save_for_backward_submit = False
if train_lora and self._lora_pointers_dirty:
self.update_lora_pointers()
self._lora_pointers_dirty = False
gpu_output, all_qlens = self._submit_and_compute_gpu(
hidden_states,
topk_ids,
topk_weights,
save_for_backward_submit,
)
# Use KTMoEFunction whenever backward is needed so KT backward and LoRA
# gradient paths remain connected.
if use_autograd_path:
lora_ref = hidden_states.new_empty(())
if train_lora and self._peft_lora_modules:
for expert_loras in self._peft_lora_modules.values():
for lora_A, lora_B in expert_loras.values():
if hasattr(lora_A, 'weight') and lora_A.weight.requires_grad:
lora_ref = lora_A.weight
break
if lora_ref.numel() > 0:
break
moe_output = KTMoEFunction.apply(
hidden_states,
topk_ids,
topk_weights,
self.wrapper,
lora_ref,
self.hidden_size,
self.moe_config.num_experts_per_tok,
self.layer_idx,
save_for_backward,
train_lora,
all_qlens,
)
else:
moe_output = self._sync_forward_output_no_autograd(
hidden_states=hidden_states,
all_qlens=all_qlens,
)
if gpu_output is not None:
moe_output = moe_output + gpu_output
return moe_output
def _sync_forward_output_no_autograd(
self,
hidden_states: torch.Tensor,
all_qlens: list[int] | tuple[int, ...] | None,
) -> torch.Tensor:
"""Sync CPU expert output without creating KTMoEFunction autograd nodes."""
import torch.distributed as dist
original_device = hidden_states.device
original_dtype = hidden_states.dtype
batch_size, seq_len, _ = hidden_states.shape
qlen = batch_size * seq_len
dist_on = dist.is_initialized() and dist.get_world_size() > 1
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist_on else 1
if dist_on:
if all_qlens is None:
all_qlens_list = _all_gather_qlens(qlen, original_device, world_size)
else:
all_qlens_list = [int(q) for q in all_qlens]
if len(all_qlens_list) != world_size:
raise RuntimeError(
f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}"
)
if int(all_qlens_list[rank]) != qlen:
raise RuntimeError(
f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}"
)
total_qlen = sum(all_qlens_list)
if rank == 0:
if self.wrapper is None:
raise RuntimeError("Rank0 wrapper is required in distributed KT overlap path.")
cpu_output = self.wrapper.sync_forward(output_device=original_device)
cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, self.hidden_size)
offsets = _qlen_offsets(all_qlens_list)
scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
else:
scatter_list = None
output_flat = _dist_scatter_varlen_from_rank0(
rank0_chunks=scatter_list,
all_qlens=all_qlens_list,
rank=rank,
world_size=world_size,
feature_shape=(self.hidden_size,),
device=original_device,
dtype=original_dtype,
)
output = output_flat.view(batch_size, seq_len, self.hidden_size)
del output_flat
return output
if self.wrapper is not None:
cpu_output = self.wrapper.sync_forward(output_device=original_device)
output = cpu_output.view(batch_size, seq_len, self.hidden_size).to(dtype=original_dtype)
return output
return torch.empty(batch_size, seq_len, self.hidden_size, device=original_device, dtype=original_dtype)
def _compute_routing(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# Run routing under no_grad to avoid creating autograd nodes whose
# SavedVariables become orphan holders inside gradient checkpoint.
# The gate is frozen during LoRA fine-tuning and the main gradient
# flows through KTMoEFunction.backward()'s grad_input, so the
# routing gradient contribution to hidden_states can be safely dropped.
with torch.no_grad():
router = getattr(self, self._router_attr)
if self.router_type == "deepseek_gate":
# DeepSeek V3's MoEGate has `assert not self.training` in its noaux_tc
# routing path because the HF model is an inference-only port.
# For LoRA fine-tuning the router is frozen, so eval() is safe.
was_training = router.training
if was_training:
router.eval()
router_output = router(hidden_states)
if was_training:
router.train()
if len(router_output) == 2:
topk_ids, topk_weights = router_output
else:
topk_ids, topk_weights = router_output[0], router_output[1]
if topk_weights.is_floating_point():
topk_weights = topk_weights.to(torch.bfloat16)
return topk_ids, topk_weights
router_output = router(hidden_states.view(-1, self.hidden_size))
# transformers v5 TopKRouter returns (router_logits, router_scores, router_indices)
# directly — scores/indices are already topk-normalized.
if isinstance(router_output, tuple):
if len(router_output) >= 3:
_logits, topk_weights, topk_ids = router_output[0], router_output[1], router_output[2]
if topk_weights.is_floating_point():
topk_weights = topk_weights.to(torch.bfloat16)
return topk_ids, topk_weights
router_output = router_output[0]
router_logits = router_output
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, self.moe_config.num_experts_per_tok, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(torch.bfloat16)
return topk_ids, topk_weights
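# Numeric sketch (illustrative addition, not part of the original code) of the
# generic routing path above, for one token with 4 experts and
# num_experts_per_tok = 2:
#   logits             = [2.0, 1.0, 0.0, 0.0]
#   softmax             ~ [0.6103, 0.2245, 0.0826, 0.0826]
#   top-2               -> weights [0.6103, 0.2245], ids [0, 1]
#   renormalized top-2  -> [0.7311, 0.2689]  (then cast to bfloat16)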
def _submit_and_compute_gpu(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
save_for_backward: bool,
) -> tuple[torch.Tensor | None, list[int] | None]:
import torch.distributed as dist
batch_size, seq_len, _ = hidden_states.shape
original_device = hidden_states.device
original_dtype = hidden_states.dtype
dist_on = dist.is_initialized() and dist.get_world_size() > 1
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist_on else 1
qlen = batch_size * seq_len
if dist_on:
all_qlens = _all_gather_qlens(qlen, original_device, world_size)
if int(all_qlens[rank]) != qlen:
raise RuntimeError(
f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens[rank]}"
)
total_qlen = sum(all_qlens)
hs_flat = hidden_states.view(qlen, self.hidden_size).contiguous()
expert_ids = topk_ids.view(qlen, self.moe_config.num_experts_per_tok).contiguous()
weights = topk_weights.view(qlen, self.moe_config.num_experts_per_tok).contiguous()
submit_hs = hs_flat.detach()
submit_ids = expert_ids.detach()
submit_wts = weights.detach()
gathered_hs = _dist_gather_varlen_to_rank0(
submit_hs,
all_qlens=all_qlens,
rank=rank,
world_size=world_size,
)
gathered_ids = _dist_gather_varlen_to_rank0(
submit_ids,
all_qlens=all_qlens,
rank=rank,
world_size=world_size,
)
gathered_wts = _dist_gather_varlen_to_rank0(
submit_wts,
all_qlens=all_qlens,
rank=rank,
world_size=world_size,
)
if rank == 0:
all_hs = torch.cat(gathered_hs, dim=0)
all_ids = torch.cat(gathered_ids, dim=0)
all_wts = torch.cat(gathered_wts, dim=0)
self.wrapper.submit_forward(
all_hs,
all_ids,
all_wts,
save_for_backward=save_for_backward,
)
# Keep shared/lora experts local to avoid qlen_max-style amplification.
gpu_output = None
if self.shared_experts is not None:
gpu_output = self.shared_experts(hidden_states)
gpu_output = gpu_output.to(dtype=original_dtype)
if self.lora_experts is not None:
lora_out = self.lora_experts(hidden_states)
gpu_output = lora_out if gpu_output is None else gpu_output + lora_out
return gpu_output, all_qlens
else:
# ---- Single-GPU path: submit + GPU compute ----
input_flat = hidden_states.view(qlen, self.hidden_size)
expert_ids = topk_ids.view(qlen, self.moe_config.num_experts_per_tok)
weights = topk_weights.view(qlen, self.moe_config.num_experts_per_tok)
# Avoid passing graph-attached tensors into C++ cache.
submit_hs = input_flat.detach()
submit_ids = expert_ids.detach()
submit_wts = weights.detach()
self.wrapper.submit_forward(
submit_hs,
submit_ids,
submit_wts,
save_for_backward=save_for_backward,
)
# GPU compute: shared_experts + lora_experts
gpu_output = None
if self.shared_experts is not None:
gpu_output = self.shared_experts(hidden_states)
if self.lora_experts is not None:
lora_out = self.lora_experts(hidden_states)
gpu_output = lora_out if gpu_output is None else gpu_output + lora_out
return gpu_output, None
def update_lora_pointers(self):
"""Sync PEFT LoRA weights to C++ kernel after optimizer update."""
# Skip if wrapper is None (non-rank-0 processes)
if self.wrapper is None:
return
# Skip if wrapper is not properly initialized
if not getattr(self.wrapper, "_weights_loaded", False):
logger.warning(f"Layer {self.layer_idx}: Skipping update_lora_pointers - weights not loaded")
return
if not getattr(self.wrapper, "_lora_initialized", False):
logger.warning(f"Layer {self.layer_idx}: Skipping update_lora_pointers - LoRA not initialized")
return
# PEFT weights are views into wrapper's contiguous buffers —
# optimizer.step() already updated them in-place, just re-sync to C++.
self.wrapper.update_lora_weights()


@@ -0,0 +1,803 @@
# PEFT LoRA adaptation utilities for SFT
# SPDX-License-Identifier: Apache-2.0
"""
PEFT LoRA integration for KT-Kernel MoE training.
Handles:
- LoRA Expert modules (LoRAExpertMLP, LoRAExperts)
- PEFT LoRA adaptation onto KT wrappers (contiguous buffer views, grad buffers)
- LoRA parameter collection for optimizer injection
- Checkpoint save/load for lora_experts
"""
from __future__ import annotations
import logging
import math
import os
import re
import torch
import torch.nn as nn
from .arch import MOEArchConfig
logger = logging.getLogger(__name__)
# =============================================================================
# LoRA Experts Modules
# =============================================================================
class LoRAExpertMLP(nn.Module):
"""Single LoRA Expert with SwiGLU activation structure."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
self.le_gate = nn.Linear(hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.le_up = nn.Linear(hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.le_down = nn.Linear(intermediate_size, hidden_size, bias=False, device=device, dtype=dtype)
self.act_fn = nn.SiLU()
nn.init.zeros_(self.le_down.weight)
nn.init.kaiming_uniform_(self.le_gate.weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.le_up.weight, a=math.sqrt(5))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.le_down(self.act_fn(self.le_gate(x)) * self.le_up(x))
class LoRAExperts(nn.Module):
"""LoRA Experts module containing multiple LoRA Expert MLPs."""
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
self.experts = nn.ModuleList(
[LoRAExpertMLP(hidden_size, intermediate_size, device, dtype) for _ in range(num_experts)]
)
self.num_experts = num_experts
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
output = torch.zeros_like(hidden_states)
for expert in self.experts:
output = output + expert(hidden_states)
return output / self.num_experts
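# Illustrative sketch (not part of the original file): a tiny LoRAExperts
# block on CPU run on a dummy activation. Sizes are toy values.
def _example_lora_experts_forward() -> torch.Tensor:
    experts = LoRAExperts(
        num_experts=2, hidden_size=8, intermediate_size=16, device="cpu", dtype=torch.float32
    )
    x = torch.randn(1, 4, 8)  # [batch, seq, hidden]
    out = experts(x)          # average of the per-expert SwiGLU outputs
    # le_down is zero-initialized, so the very first forward returns zeros.
    assert out.shape == x.shape and torch.all(out == 0)
    return out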
# =============================================================================
# LoRA Parameter Collection
# =============================================================================
def _find_kt_wrappers(model: nn.Module):
"""Find _kt_wrappers on model, unwrapping PEFT/other wrappers if needed."""
wrappers = getattr(model, "_kt_wrappers", None)
if wrappers is None:
base_model = model
for attr in ("base_model", "model"):
if hasattr(base_model, attr):
base_model = getattr(base_model, attr)
wrappers = getattr(base_model, "_kt_wrappers", None)
if wrappers:
break
return wrappers
def get_kt_lora_params(model: nn.Module) -> list[nn.Parameter]:
"""Get all MoE LoRA parameters from KT model.
Returns PEFT LoRA parameters from expert modules and lora_experts parameters.
"""
params: list[nn.Parameter] = []
wrappers = _find_kt_wrappers(model)
if wrappers:
for wrapper in wrappers:
# PEFT LoRA parameters (from _peft_lora_modules)
peft_lora_modules = getattr(wrapper, "_peft_lora_modules", None)
if peft_lora_modules is not None:
for expert_loras in peft_lora_modules.values():
for lora_A, lora_B in expert_loras.values():
if hasattr(lora_A, 'weight') and lora_A.weight.requires_grad:
params.append(lora_A.weight)
if hasattr(lora_B, 'weight') and lora_B.weight.requires_grad:
params.append(lora_B.weight)
# Fused expert LoRA parameters (KT-managed, not PEFT)
fused_params = getattr(wrapper, "_fused_expert_lora_params", None)
if fused_params is not None:
params.extend(fused_params)
# lora_experts parameters (separate feature)
if getattr(wrapper, "lora_experts", None) is not None:
params.extend(wrapper.lora_experts.parameters())
return params
# =============================================================================
# PEFT LoRA Adaptation
# =============================================================================
def kt_adapt_peft_lora(model: nn.Module) -> None:
"""
Adapt PEFT LoRA on expert modules for KT kernel.
After PEFT injects LoRA adapters onto expert Linear modules, this function:
1. Detects PEFT LoRA presence and rank on each wrapper's experts
2. Stores references to PEFT LoRA modules on the wrapper (for backward gradient writing)
3. Syncs initial PEFT LoRA weights to the C++ KT kernel (rank 0 only)
PEFT LoRA remains active and is managed by PEFT. No separate KT lora_params created.
Optimizer updates PEFT LoRA directly, and KT kernel reads from PEFT LoRA on each forward.
Should be called after PEFT LoRA injection and before create_optimizer.
"""
import torch.distributed as dist
wrappers = _find_kt_wrappers(model)
if not wrappers:
logger.info("[kt_adapt_peft_lora] No _kt_wrappers found, skipping")
return
is_rank_0 = True
if dist.is_initialized():
is_rank_0 = dist.get_rank() == 0
adapted_count = 0
for wrapper in wrappers:
moe_config = wrapper.moe_config
layer_idx = wrapper.layer_idx
experts_attr = getattr(wrapper, "_experts_attr", "experts")
experts = getattr(wrapper, experts_attr, None)
if experts is None:
continue
# Fused experts (transformers v5): PEFT cannot auto-attach LoRA to packed
# nn.Parameter tensors. Create KT-managed LoRA buffers with proper init,
# wrap as nn.Parameter for optimizer, and pre-assign .grad for C++ backward.
if getattr(wrapper, "_fused_experts", False):
lora_rank = getattr(wrapper, "_lora_rank", 1)
lora_buffers, lora_grad_buffers, lora_params = _create_fused_expert_lora_buffers(
wrapper, moe_config, lora_rank, torch.bfloat16,
)
if is_rank_0 and wrapper.wrapper is not None:
all_buffers = {}
all_buffers.update(lora_buffers)
all_buffers.update(lora_grad_buffers)
wrapper.wrapper.init_lora_weights(**all_buffers)
logger.info(
f"[kt_adapt_peft_lora] Layer {layer_idx}: fused expert LoRA "
f"(r={lora_rank}, E={moe_config.expert_num})"
)
wrapper._fused_expert_lora_params = lora_params
wrapper._peft_lora_modules = None
adapted_count += 1
continue
if len(experts) == 0:
continue
# Collect references to PEFT LoRA modules for each expert
# Structure: {expert_idx: {proj_name: (lora_A_module, lora_B_module)}}
peft_lora_modules = {}
gate_name, up_name, down_name = moe_config.weight_names
for expert_idx, expert in enumerate(experts):
expert_loras = {}
for proj_name in (gate_name, up_name, down_name):
proj = getattr(expert, proj_name, None)
if proj is None:
continue
lora_A = getattr(proj, "lora_A", None)
lora_B = getattr(proj, "lora_B", None)
if lora_A is not None and lora_B is not None:
# Get the actual Linear modules (inside ModuleDict if using adapters)
if isinstance(lora_A, nn.ModuleDict):
adapter_name = "default"
active = getattr(proj, "active_adapter", ["default"])
if isinstance(active, (list, tuple)) and active:
adapter_name = active[0]
# ModuleDict doesn't have .get(), use [] with in check
lora_A = lora_A[adapter_name] if adapter_name in lora_A else None
lora_B = lora_B[adapter_name] if adapter_name in lora_B else None
if lora_A is not None and lora_B is not None:
expert_loras[proj_name] = (lora_A, lora_B)
if expert_loras:
peft_lora_modules[expert_idx] = expert_loras
# Store PEFT LoRA references on wrapper
wrapper._peft_lora_modules = peft_lora_modules
if not peft_lora_modules:
raise RuntimeError(
f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found on any expert. "
f"Check that PEFT lora_target includes expert modules."
)
# Allocate contiguous bf16 buffers and populate with initial PEFT values (all ranks)
lora_buffers = _create_lora_view_buffers(peft_lora_modules, moe_config, torch.bfloat16)
lora_grad_buffers = _create_lora_grad_buffers(peft_lora_modules, moe_config)
# Rank 0: pass buffers to C++ wrapper (init_lora_weights stores them via .contiguous() no-op)
if is_rank_0 and wrapper.wrapper is not None:
# concat lora_buffers and lora_grad_buffers into single dict
lora_buffers.update(lora_grad_buffers)
wrapper.wrapper.init_lora_weights(**lora_buffers)
logger.info(f"[kt_adapt_peft_lora] Layer {layer_idx}: synced PEFT LoRA to C++ kernel")
# All ranks: replace PEFT weights with views into the contiguous buffers
_replace_peft_weights_with_views(peft_lora_modules, lora_buffers, lora_grad_buffers, moe_config)
adapted_count += 1
# After collecting all LoRA references, shrink expert base weight parameters
# from their original shape (e.g. [768, 2048]) to scalar (1,).
# These base weights were already replaced with tiny-storage stride=[0] placeholders
# by _clear_original_expert_weights(). They have correct shape but serve no purpose
# after PEFT injection. FSDP2 broadcasts ALL non-DTensor params, and uses
# torch.empty(param.size()) on non-rank-0 — with the original shape this wastes
# ~28GB+. Shrinking to (1,) reduces broadcast cost to ~30KB total.
shrunk_count = 0
shrunk_saved_bytes = 0
for wrapper in wrappers:
experts_attr = getattr(wrapper, "_experts_attr", "experts")
experts = getattr(wrapper, experts_attr, None)
if experts is None:
continue
if getattr(wrapper, "_fused_experts", False):
continue
for expert in experts:
for param_name, param in list(expert.named_parameters()):
if param.requires_grad:
continue # Skip trainable params (LoRA weights)
try:
storage_bytes = param.data.untyped_storage().nbytes()
except Exception:
continue
if storage_bytes > 2:
continue # Skip non-placeholder params
# This is a tiny-storage placeholder (base weight) — replace with
# a scalar (1,) parameter so FSDP broadcasts only 1 element.
original_numel = param.nelement()
parts = param_name.split(".")
container = expert
for p in parts[:-1]:
container = getattr(container, p)
local_name = parts[-1]
container_params = getattr(container, "_parameters", {})
if isinstance(container_params, dict) and local_name in container_params:
scalar_param = nn.Parameter(
torch.empty(1, dtype=param.dtype, device="cpu"),
requires_grad=False,
)
container_params[local_name] = scalar_param
shrunk_count += 1
shrunk_saved_bytes += (original_numel - 1) * param.element_size()
if shrunk_count > 0:
logger.info(
f"[kt_adapt_peft_lora] Shrunk {shrunk_count} expert base weight params "
f"to shape (1,), FSDP broadcast savings={shrunk_saved_bytes / 1024 / 1024:.1f} MB"
)
logger.info(f"[kt_adapt_peft_lora] Adapted {adapted_count} layers (PEFT LoRA mode)")
# =============================================================================
# Contiguous Buffer Creation
# =============================================================================
def _create_lora_view_buffers(
peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]],
moe_config: MOEArchConfig,
dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
"""
Allocate contiguous buffers and populate with initial PEFT LoRA values.
Returns dict with gate_lora_a, gate_lora_b, up_lora_a, up_lora_b,
down_lora_a, down_lora_b each shape [num_experts, ...].
"""
gate_name, up_name, down_name = moe_config.weight_names
num_experts = moe_config.expert_num
first_expert_loras = peft_lora_modules.get(0, {})
if not first_expert_loras:
raise RuntimeError("No PEFT LoRA found on expert 0")
gate_lora = first_expert_loras.get(gate_name)
if gate_lora is None:
raise RuntimeError(f"No PEFT LoRA found on expert 0 {gate_name}")
lora_rank = gate_lora[0].weight.shape[0]
hidden_size = gate_lora[0].weight.shape[1]
intermediate_size = gate_lora[1].weight.shape[0]
buffers = {
"gate_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"),
"gate_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"),
"up_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"),
"up_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"),
"down_lora_a": torch.zeros(num_experts, lora_rank, intermediate_size, dtype=dtype, device="cpu"),
"down_lora_b": torch.zeros(num_experts, hidden_size, lora_rank, dtype=dtype, device="cpu"),
}
proj_to_keys = {
gate_name: ("gate_lora_a", "gate_lora_b"),
up_name: ("up_lora_a", "up_lora_b"),
down_name: ("down_lora_a", "down_lora_b"),
}
for expert_idx in range(num_experts):
expert_loras = peft_lora_modules.get(expert_idx, {})
for proj_name, (key_a, key_b) in proj_to_keys.items():
if proj_name in expert_loras:
lora_A, lora_B = expert_loras[proj_name]
buffers[key_a][expert_idx].copy_(lora_A.weight.data.to(dtype=dtype))
buffers[key_b][expert_idx].copy_(lora_B.weight.data.to(dtype=dtype))
return buffers
def _create_lora_grad_buffers(
peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]],
moe_config: MOEArchConfig,
dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
"""
Allocate contiguous gradient buffers for PEFT LoRA.
Returns dict with grad_gate_lora_a, grad_gate_lora_b, etc. each shape [num_experts, ...].
"""
gate_name, up_name, down_name = moe_config.weight_names
num_experts = moe_config.expert_num
first_expert_loras = peft_lora_modules.get(0, {})
if not first_expert_loras:
raise RuntimeError("No PEFT LoRA found on expert 0")
gate_lora = first_expert_loras.get(gate_name)
if gate_lora is None:
raise RuntimeError(f"No PEFT LoRA found on expert 0 {gate_name}")
lora_rank = gate_lora[0].weight.shape[0]
hidden_size = gate_lora[0].weight.shape[1]
intermediate_size = gate_lora[1].weight.shape[0]
buffers = {
"grad_gate_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"),
"grad_gate_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"),
"grad_up_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"),
"grad_up_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"),
"grad_down_lora_a": torch.zeros(num_experts, lora_rank, intermediate_size, dtype=dtype, device="cpu"),
"grad_down_lora_b": torch.zeros(num_experts, hidden_size, lora_rank, dtype=dtype, device="cpu"),
}
return buffers
def _create_fused_expert_lora_buffers(
wrapper,
moe_config: MOEArchConfig,
lora_rank: int,
dtype: torch.dtype = torch.bfloat16,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], list[nn.Parameter]]:
"""
Create KT-managed LoRA buffers for fused expert modules.
Fused experts store weights as 3D parameters (gate_up_proj [E, 2I, H], down_proj [E, H, I])
rather than per-expert nn.Linear modules. PEFT can't attach per-expert LoRA to these,
so we create our own LoRA buffers that the C++ kernel reads/writes directly.
Returns:
(lora_buffers, lora_grad_buffers, lora_params):
- lora_buffers: dict of weight buffers for C++ init_lora_weights()
- lora_grad_buffers: dict of grad buffers for C++ backward
- lora_params: list of nn.Parameter wrappers for the optimizer
"""
E = moe_config.expert_num
I = moe_config.intermediate_size
H = wrapper.hidden_size
r = lora_rank
logger.info(f"[_create_fused_expert_lora_buffers] E={E}, I={I}, H={H}, r={r}")
lora_buffers = {
"gate_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
"gate_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
"up_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
"up_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
"down_lora_a": torch.zeros(E, r, I, dtype=dtype, device="cpu"),
"down_lora_b": torch.zeros(E, H, r, dtype=dtype, device="cpu"),
}
for key in ("gate_lora_a", "up_lora_a", "down_lora_a"):
nn.init.kaiming_uniform_(lora_buffers[key].view(E * r, -1), a=math.sqrt(5))
lora_grad_buffers = {
"grad_gate_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
"grad_gate_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
"grad_up_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
"grad_up_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
"grad_down_lora_a": torch.zeros(E, r, I, dtype=dtype, device="cpu"),
"grad_down_lora_b": torch.zeros(E, H, r, dtype=dtype, device="cpu"),
}
lora_params = []
for key in ("gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"):
param = nn.Parameter(lora_buffers[key], requires_grad=True)
param.grad = lora_grad_buffers[f"grad_{key}"]
lora_params.append(param)
return lora_buffers, lora_grad_buffers, lora_params
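# Illustrative sketch (not part of the original file): the key property of the
# fused-expert LoRA buffers above is that each optimizer Parameter and its
# pre-assigned .grad share storage with the contiguous buffers handed to the
# C++ kernel, so gradients written by the C++ backward are immediately visible
# to the optimizer. Shapes are toy values.
def _example_fused_lora_buffer_sharing() -> None:
    E, r, H = 2, 4, 8
    buf = torch.zeros(E, r, H, dtype=torch.bfloat16)
    grad_buf = torch.zeros(E, r, H, dtype=torch.bfloat16)
    param = nn.Parameter(buf, requires_grad=True)  # shares storage with buf
    param.grad = grad_buf
    grad_buf.fill_(1.0)  # simulate the C++ backward filling the grad buffer
    assert torch.all(param.grad == 1.0)  # the optimizer sees the same storage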
# =============================================================================
# PEFT Weight View Replacement
# =============================================================================
def _replace_peft_weights_with_views(
peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]],
buffers: dict[str, torch.Tensor],
grad_buffers: dict[str, torch.Tensor],
moe_config: MOEArchConfig,
) -> None:
"""
Replace each PEFT LoRA module's .weight with a view into the contiguous buffer.
After this, optimizer.step() updates the buffer in-place via the view;
no copy needed to sync with C++.
"""
gate_name, up_name, down_name = moe_config.weight_names
num_experts = moe_config.expert_num
proj_to_keys = {
gate_name: ("gate_lora_a", "gate_lora_b"),
up_name: ("up_lora_a", "up_lora_b"),
down_name: ("down_lora_a", "down_lora_b"),
}
_replaced = 0
_first_logged = False
for expert_idx in range(num_experts):
expert_loras = peft_lora_modules.get(expert_idx, {})
for proj_name, (key_a, key_b) in proj_to_keys.items():
if proj_name not in expert_loras:
continue
lora_A, lora_B = expert_loras[proj_name]
# Log before/after for first replacement to verify .data assignment
if not _first_logged:
_old_id_a = id(lora_A.weight)
_old_ptr_a = lora_A.weight.data_ptr()
# Use .data assignment to keep the same Parameter objects.
# This preserves optimizer references (which point to these objects).
# Creating new nn.Parameter() would break the optimizer link.
lora_A.weight.data = buffers[key_a][expert_idx]
lora_B.weight.data = buffers[key_b][expert_idx]
lora_A.weight.requires_grad_(True)
lora_B.weight.requires_grad_(True)
lora_A.weight.grad = grad_buffers["grad_" + key_a][expert_idx]
lora_B.weight.grad = grad_buffers["grad_" + key_b][expert_idx]
if not _first_logged:
_new_id_a = id(lora_A.weight)
_new_ptr_a = lora_A.weight.data_ptr()
_buf_ptr_a = buffers[key_a][expert_idx].data_ptr()
_has_grad = lora_A.weight.grad is not None
logger.info(
"[_replace_peft_weights_with_views] first param: "
"id %s->%s (same=%s) data_ptr %s->%s buf_ptr=%s (match=%s) "
"has_grad=%s requires_grad=%s shape=%s",
_old_id_a, _new_id_a, _old_id_a == _new_id_a,
_old_ptr_a, _new_ptr_a, _buf_ptr_a, _new_ptr_a == _buf_ptr_a,
_has_grad, lora_A.weight.requires_grad, tuple(lora_A.weight.shape),
)
_first_logged = True
_replaced += 1
logger.info("[_replace_peft_weights_with_views] replaced %d param pairs", _replaced)
# =============================================================================
# Runtime LoRA Pointer Updates
# =============================================================================
def update_kt_lora_pointers(model: nn.Module):
"""Mark KT wrapper LoRA pointers as dirty after optimizer.step()."""
wrappers = _find_kt_wrappers(model)
if wrappers:
for wrapper in wrappers:
wrapper._lora_pointers_dirty = True
# =============================================================================
# Cross-Rank Gradient Synchronization
# =============================================================================
def sync_kt_lora_gradients(model: nn.Module) -> None:
"""
Synchronize KT-managed LoRA gradients across ranks.
KT computes expert LoRA gradients only on rank 0 (gather/scatter path). This function all-reduces the
per-layer contiguous grad buffers so every rank ends up with identical values, ensuring:
- gradient clipping sees identical grads on every rank
- optimizer.step() applies identical updates
"""
import torch.distributed as dist
if not (dist.is_initialized() and dist.get_world_size() > 1):
return
world_size = dist.get_world_size()
if world_size <= 1:
return
params = get_kt_lora_params(model)
if not params:
return
for param in params:
if param.grad is not None:
# Move grad to the same device as the parameter for all-reduce
# Then move back to CPU
original_device = param.grad.device
if original_device.type == "cpu":
# All-reduce on CPU might be slow; consider using a GPU buffer
grad_gpu = param.grad.cuda()
dist.all_reduce(grad_gpu, op=dist.ReduceOp.SUM)
grad_gpu.div_(world_size)
param.grad.copy_(grad_gpu.cpu())
else:
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
param.grad.div_(world_size)
# =============================================================================
# Checkpoint Save/Load
# =============================================================================
def save_lora_experts_to_adapter(model: nn.Module, output_dir: str) -> None:
"""
Save LoRA Experts weights to adapter file by merging with existing Attention LoRA.
"""
from safetensors import safe_open
from safetensors.torch import save_file
wrappers = _find_kt_wrappers(model) or []
if not wrappers:
logger.warning("No KT wrappers found, skipping LoRA Experts saving")
return
adapter_file = os.path.join(output_dir, "adapter_model.safetensors")
if not os.path.exists(adapter_file):
adapter_file_bin = os.path.join(output_dir, "adapter_model.bin")
if os.path.exists(adapter_file_bin):
state_dict = torch.load(adapter_file_bin, map_location="cpu", weights_only=True)
else:
logger.warning(f"No existing adapter file found at {output_dir}, creating new one")
state_dict = {}
else:
state_dict = {}
with safe_open(adapter_file, framework="pt") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
lora_expert_count = 0
for wrapper in wrappers:
if wrapper.lora_experts is None:
continue
layer_idx = wrapper.layer_idx
for expert_idx, expert in enumerate(wrapper.lora_experts.experts):
base_key = f"base_model.model.model.layers.{layer_idx}.mlp.lora_experts.{expert_idx}"
state_dict[f"{base_key}.le_gate.weight"] = expert.le_gate.weight.data.cpu().clone()
state_dict[f"{base_key}.le_up.weight"] = expert.le_up.weight.data.cpu().clone()
state_dict[f"{base_key}.le_down.weight"] = expert.le_down.weight.data.cpu().clone()
lora_expert_count += 3
logger.debug(f"Added LoRA Experts for layer {layer_idx} ({len(wrapper.lora_experts.experts)} experts)")
output_file = os.path.join(output_dir, "adapter_model.safetensors")
save_file(state_dict, output_file, metadata={"format": "pt"})
logger.info(
f"Saved LoRA Experts to {output_file}: "
f"{len(wrappers)} layers, {lora_expert_count} LoRA Expert tensors added, "
f"{len(state_dict)} total tensors"
)
def save_kt_moe_to_adapter(model: nn.Module, output_dir: str) -> None:
"""
Unified function to save KT MoE weights to adapter file.
Note: Per-expert PEFT LoRA is saved by PEFT directly, not here.
This function only handles lora_experts (a separate feature).
"""
wrappers = _find_kt_wrappers(model) or []
if not wrappers:
logger.info("[save_kt_moe] No KT wrappers found, skipping")
return
has_lora_experts = any(w.lora_experts is not None for w in wrappers)
has_fused_lora = any(getattr(w, "_fused_expert_lora_params", None) is not None for w in wrappers)
if has_lora_experts:
save_lora_experts_to_adapter(model, output_dir)
if has_fused_lora:
_save_fused_expert_lora(wrappers, output_dir)
if not has_lora_experts and not has_fused_lora:
logger.info("[save_kt_moe] No lora_experts or fused expert LoRA in KT wrappers")
def _save_fused_expert_lora(wrappers: list, output_dir: str) -> None:
"""Save fused expert LoRA params to a safetensors file."""
from safetensors.torch import save_file
names = ["gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"]
tensors = {}
for w in wrappers:
fused = getattr(w, "_fused_expert_lora_params", None)
if fused is None:
continue
for param, name in zip(fused, names):
key = f"layers.{w.layer_idx}.experts.{name}"
tensors[key] = param.data.clone()
if tensors:
path = os.path.join(output_dir, "fused_expert_lora.safetensors")
save_file(tensors, path)
logger.info(f"[save_kt_moe] Saved {len(tensors)} fused expert LoRA tensors to {path}")
def _load_fused_expert_lora(wrappers: list, adapter_path: str) -> None:
"""Load fused expert LoRA params from a safetensors file into existing wrapper buffers."""
path = os.path.join(adapter_path, "fused_expert_lora.safetensors")
if not os.path.isfile(path):
logger.warning(f"No fused_expert_lora.safetensors found at {adapter_path}")
return
from safetensors.torch import load_file
saved = load_file(path)
names = ["gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"]
wrapper_map = {w.layer_idx: w for w in wrappers}
loaded_count = 0
for key, tensor in saved.items():
parts = key.split(".")
if len(parts) != 4 or parts[0] != "layers" or parts[2] != "experts":
logger.warning(f"Unexpected key in fused_expert_lora.safetensors: {key}")
continue
layer_idx = int(parts[1])
name = parts[3]
if name not in names:
continue
wrapper = wrapper_map.get(layer_idx)
if wrapper is None:
continue
fused = getattr(wrapper, "_fused_expert_lora_params", None)
if fused is None:
continue
param_idx = names.index(name)
fused[param_idx].data.copy_(tensor)
loaded_count += 1
logger.info(f"[_load_fused_expert_lora] Loaded {loaded_count} tensors from {path}")
def load_lora_experts_from_adapter(model: nn.Module, adapter_path: str) -> None:
"""
Load LoRA Experts weights from adapter file into KT wrappers.
"""
from safetensors import safe_open
wrappers = _find_kt_wrappers(model) or []
if not wrappers:
logger.warning("No KT wrappers found, skipping LoRA Experts loading")
return
wrapper_map = {w.layer_idx: w for w in wrappers if w.lora_experts is not None}
if not wrapper_map:
logger.warning("No LoRA Experts found in KT wrappers, skipping")
return
# Prefer dedicated lora_experts file, fallback to adapter file
adapter_file = os.path.join(adapter_path, "lora_experts.safetensors")
if not os.path.exists(adapter_file):
adapter_file = os.path.join(adapter_path, "adapter_model.safetensors")
if not os.path.exists(adapter_file):
adapter_file = os.path.join(adapter_path, "adapter_model.bin")
if not os.path.exists(adapter_file):
logger.warning(f"No lora_experts or adapter file found at {adapter_path}")
return
logger.info(f"Loading LoRA Experts from {adapter_file}")
lora_expert_pattern = re.compile(
r"base_model\.model\.model\.layers\.(\d+)\.mlp\.lora_experts\.(\d+)\.(le_gate|le_up|le_down)\.weight"
)
layer_weights = {}
with safe_open(adapter_file, framework="pt") as f:
for key in f.keys():
match = lora_expert_pattern.match(key)
if match:
layer_idx = int(match.group(1))
expert_idx = int(match.group(2))
proj_name = match.group(3)
layer_weights.setdefault(layer_idx, {}).setdefault(expert_idx, {})[proj_name] = f.get_tensor(key)
loaded_count = 0
for layer_idx, experts_dict in layer_weights.items():
if layer_idx not in wrapper_map:
logger.warning(f"No LoRA Experts for layer {layer_idx}, skipping")
continue
wrapper = wrapper_map[layer_idx]
for expert_idx, proj_dict in experts_dict.items():
if expert_idx >= len(wrapper.lora_experts.experts):
continue
expert = wrapper.lora_experts.experts[expert_idx]
if "le_gate" in proj_dict:
expert.le_gate.weight.data.copy_(proj_dict["le_gate"].to(expert.le_gate.weight.device))
if "le_up" in proj_dict:
expert.le_up.weight.data.copy_(proj_dict["le_up"].to(expert.le_up.weight.device))
if "le_down" in proj_dict:
expert.le_down.weight.data.copy_(proj_dict["le_down"].to(expert.le_down.weight.device))
loaded_count += 1
logger.info(f"Loaded LoRA Experts for {loaded_count} experts from {adapter_path}")
def load_kt_moe_from_adapter(model: nn.Module, adapter_path: str) -> None:
"""
Unified function to load KT MoE weights from adapter file.
Note: Per-expert PEFT LoRA is loaded by PEFT directly, not here.
This function only handles lora_experts (a separate feature).
"""
wrappers = _find_kt_wrappers(model) or []
if not wrappers:
logger.warning("No KT wrappers found, skipping KT MoE loading")
return
has_lora_experts = any(w.lora_experts is not None for w in wrappers)
has_fused_lora = any(getattr(w, "_fused_expert_lora_params", None) is not None for w in wrappers)
if has_lora_experts:
load_lora_experts_from_adapter(model, adapter_path)
if has_fused_lora:
_load_fused_expert_lora(wrappers, adapter_path)
if not has_lora_experts and not has_fused_lora:
logger.info("No lora_experts or fused expert LoRA in KT wrappers")


@@ -0,0 +1,557 @@
# Weight extraction and loading utilities for SFT
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import json
import logging
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass
import torch
import torch.nn as nn
from .arch import MOEArchConfig
from .dist_utils import _maybe_zero3_gathered_parameters
logger = logging.getLogger(__name__)
try:
from safetensors import safe_open
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
safe_open = None
# =============================================================================
# Weight Extraction
# =============================================================================
def extract_moe_weights(
moe_module: nn.Module, moe_config: MOEArchConfig
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Extract MoE expert weights from the module.
Returns (gate_proj, up_proj, down_proj) with shape
[expert_num, out_features, in_features].
Supports two formats:
- ModuleList of Linear experts (transformers v4 style)
- Fused Parameters (transformers v5 style): single module with
``gate_up_proj`` [E, 2*I, H] and ``down_proj`` [E, H, I] tensors.
"""
from .arch import detect_fused_experts
experts = getattr(moe_module, moe_config.experts_attr)
# Fused format (transformers v5): a single nn.Module with gate_up_proj/down_proj tensors
if detect_fused_experts(experts):
gate_up = getattr(experts, "gate_up_proj").data
down_fused = getattr(experts, "down_proj").data
# gate_up_proj is [E, 2*I, H], split into gate [E, I, H] and up [E, I, H]
intermediate = gate_up.shape[1] // 2
gate_proj = gate_up[:, :intermediate, :].contiguous()
up_proj = gate_up[:, intermediate:, :].contiguous()
# down_proj is already [E, H, I]
down_proj = down_fused.contiguous()
return gate_proj, up_proj, down_proj
gate_name, up_name, down_name = moe_config.weight_names
gather_params: list[torch.nn.Parameter] = []
for expert in experts:
for weight_name in (gate_name, up_name, down_name):
proj = getattr(expert, weight_name, None)
if proj is not None and hasattr(proj, "weight"):
# Handle PEFT LoRA wrapped modules
weight = proj.weight
if isinstance(weight, torch.Tensor):
gather_params.append(weight)
elif hasattr(weight, "data"):
gather_params.append(weight.data)
with _maybe_zero3_gathered_parameters(gather_params):
gate_weights = []
up_weights = []
down_weights = []
for expert in experts:
# Handle PEFT LoRA wrapped modules - get weight tensor properly
gate_proj = getattr(expert, gate_name)
up_proj_mod = getattr(expert, up_name)
down_proj_mod = getattr(expert, down_name)
# Get weight tensors, handling both regular Linear and PEFT LoRA wrapped
def get_weight_tensor(mod):
weight = mod.weight
if isinstance(weight, torch.Tensor):
return weight.data
elif hasattr(weight, "data"):
return weight.data
else:
raise ValueError(f"Cannot extract weight from {type(mod)}, weight type={type(weight)}")
gate_weights.append(get_weight_tensor(gate_proj))
up_weights.append(get_weight_tensor(up_proj_mod))
down_weights.append(get_weight_tensor(down_proj_mod))
gate_proj = torch.stack(gate_weights, dim=0)
up_proj = torch.stack(up_weights, dim=0)
down_proj = torch.stack(down_weights, dim=0)
return gate_proj, up_proj, down_proj
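# Illustrative sketch (not part of the original file) of the fused-format split
# performed above: gate_up_proj packs gate and up along dim 1, so the first
# half of that dim is the gate projection and the second half is up. Sizes are
# toy values.
def _example_fused_gate_up_split() -> None:
    E, I, H = 2, 6, 4
    gate_up = torch.arange(E * 2 * I * H, dtype=torch.float32).reshape(E, 2 * I, H)
    gate = gate_up[:, :I, :].contiguous()  # [E, I, H]
    up = gate_up[:, I:, :].contiguous()    # [E, I, H]
    assert gate.shape == up.shape == (E, I, H)
    assert torch.equal(torch.cat([gate, up], dim=1), gate_up)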
def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchConfig) -> None:
"""
Clear original expert weights to free memory after KT weights are loaded.
"""
from .arch import detect_fused_experts
experts = getattr(moe_module, moe_config.experts_attr, None)
if experts is None:
return
# Fused format: replace gate_up_proj/down_proj tensors with zero-storage placeholders
if detect_fused_experts(experts):
for name in ("gate_up_proj", "down_proj"):
param = getattr(experts, name, None)
if not isinstance(param, torch.nn.Parameter):
continue
original_dtype = param.dtype
tiny_storage = torch.UntypedStorage(1, device="cpu")
fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_(
tiny_storage, storage_offset=0, size=param.shape,
stride=[0] * len(param.shape),
)
experts._parameters[name] = nn.Parameter(fake_tensor, requires_grad=False)
return
def _iter_weight_params():
for expert in experts:
for weight_name in moe_config.weight_names:
proj = getattr(expert, weight_name, None)
if proj is None or not hasattr(proj, "weight"):
continue
parametrizations = getattr(proj, "parametrizations", None)
parametrized_weight = getattr(parametrizations, "weight", None) if parametrizations is not None else None
if parametrized_weight is not None:
original = getattr(parametrized_weight, "original", None)
if isinstance(original, torch.nn.Parameter):
yield proj, parametrized_weight, "original", original
continue
direct_weight = getattr(proj, "_parameters", {}).get("weight")
if isinstance(direct_weight, torch.nn.Parameter):
yield proj, proj, "weight", direct_weight
continue
# Fallback: `weight` can be a non-settable property (e.g. parametrizations) or a non-Parameter.
weight_attr = getattr(proj, "weight", None)
if isinstance(weight_attr, torch.nn.Parameter):
yield proj, proj, "weight", weight_attr
gather_params: list[torch.nn.Parameter] = []
for _, _, _, weight_param in _iter_weight_params():
gather_params.append(weight_param)
replaced_count = 0
with _maybe_zero3_gathered_parameters(gather_params):
for proj, container, param_name, weight_param in _iter_weight_params():
original_dtype = weight_param.dtype
# Create a CPU tensor with the correct shape but NO physical memory.
# torch.empty(shape, device="cpu") unfortunately touches pages via the
# allocator, consuming real RSS. Instead, allocate a 1-byte storage and
# use set_ to give it the original shape with zero strides. The tensor
# is "valid" (correct dtype, device, shape) so PEFT can discover
# in/out features, but its storage is essentially zero-cost.
# NOTE: reading element values from this tensor is undefined -- it is
# only used for shape/dtype discovery by PEFT.
tiny_storage = torch.UntypedStorage(1, device="cpu")
fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_(
tiny_storage, storage_offset=0, size=weight_param.shape,
stride=[0] * len(weight_param.shape),
)
new_param = nn.Parameter(fake_tensor, requires_grad=False)
replaced_count += 1
# Avoid `KeyError: attribute 'weight' already exists` for parametrized modules
# where `weight` is a property and the real parameter lives elsewhere.
container_params = getattr(container, "_parameters", {})
if isinstance(container_params, dict) and param_name in container_params:
container_params[param_name] = new_param
continue
if hasattr(container, param_name):
logger.debug(
f"Skipping clearing expert weight {type(proj).__name__}.{param_name}: "
"attribute exists but is not a registered parameter."
)
continue
try:
setattr(container, param_name, new_param)
except Exception as exc:
logger.warning(
f"Failed to clear expert weight {type(proj).__name__}.{param_name}: {exc}"
)
logger.info(f"Replaced {replaced_count} expert weight params")
# =============================================================================
# kt_weight_path Loading Functions
# =============================================================================
@dataclass
class INT8ExpertWeights:
"""Container for INT8 expert weights with scales."""
gate_proj: torch.Tensor
gate_scale: torch.Tensor
up_proj: torch.Tensor
up_scale: torch.Tensor
down_proj: torch.Tensor
down_scale: torch.Tensor
def _find_safetensor_files(kt_weight_path: str) -> list[str]:
if not os.path.isdir(kt_weight_path):
raise FileNotFoundError(f"kt_weight_path directory not found: {kt_weight_path}")
safetensor_files = []
for file in sorted(os.listdir(kt_weight_path)):
if file.endswith(".safetensors"):
safetensor_files.append(os.path.join(kt_weight_path, file))
if not safetensor_files:
raise FileNotFoundError(f"No safetensors files found in {kt_weight_path}")
return safetensor_files
def _load_kt_weight_index(kt_weight_path: str) -> dict[str, str]:
if not SAFETENSORS_AVAILABLE:
raise ImportError("safetensors is required for loading kt_weight_path")
index = {}
safetensor_files = _find_safetensor_files(kt_weight_path)
for file_path in safetensor_files:
with safe_open(file_path, framework="pt") as f:
for key in f.keys():
index[key] = file_path
logger.info(f"Indexed {len(index)} tensors from {len(safetensor_files)} safetensors files")
return index
def _dequant_fp8_experts(weights: list[torch.Tensor], scales: list[torch.Tensor | None], block_size: tuple[int, int]) -> torch.Tensor:
"""Dequantize a list of FP8 expert weights and stack them (batched, vectorized).
Args:
weights: list of [out, in] float8_e4m3fn tensors (one per expert)
scales: list of [out//bs_m, in//bs_n] scale_inv tensors (one per expert, may be None)
block_size: (bs_m, bs_n)
Returns:
Stacked BF16 tensor of shape [num_experts, out, in]
"""
has_scales = scales[0] is not None
if not has_scales:
return torch.stack(weights, dim=0).to(torch.bfloat16).cpu().contiguous()
bs_m, bs_n = block_size
n = len(weights)
out_features, in_features = weights[0].shape
# Stack all experts: [N, out, in] fp8 -> reshape to blocks -> bf16
w = torch.stack(weights, dim=0) # [N, out, in] fp8
w = w.reshape(n, out_features // bs_m, bs_m, in_features // bs_n, bs_n)
w = w.to(torch.bfloat16)
# Stack all scales: [N, out//bs_m, in//bs_n] -> bf16, broadcast multiply
s = torch.stack(scales, dim=0).to(torch.bfloat16) # [N, out//bs_m, in//bs_n]
w = w * s[:, :, None, :, None]
return w.reshape(n, out_features, in_features).contiguous()
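# Illustrative sketch (not part of the original file) of the block-wise
# dequant broadcast above, using bf16 stand-ins instead of float8 so it runs
# anywhere: each (bs_m, bs_n) block of the weight is multiplied by one scale.
def _example_blockwise_scale_broadcast() -> None:
    bs_m, bs_n = 2, 2
    out_features, in_features = 4, 4
    w = torch.ones(1, out_features, in_features, dtype=torch.bfloat16)   # [N=1, out, in]
    s = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]], dtype=torch.bfloat16)   # [N, out//bs_m, in//bs_n]
    blocks = w.reshape(1, out_features // bs_m, bs_m, in_features // bs_n, bs_n)
    deq = (blocks * s[:, :, None, :, None]).reshape(1, out_features, in_features)
    # Each 2x2 block now carries its own scale: top-left block is all 1s,
    # top-right all 2s, bottom-left all 3s, bottom-right all 4s.
    assert deq[0, 0, 0] == 1.0 and deq[0, 0, 3] == 2.0
    assert deq[0, 3, 0] == 3.0 and deq[0, 3, 3] == 4.0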
def load_experts_from_checkpoint_files(
checkpoint_files: list[str],
sharded_metadata: dict | None,
layers_prefix: str,
moe_config: MOEArchConfig,
layer_idx: int,
block_size: tuple[int, int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if not SAFETENSORS_AVAILABLE:
raise ImportError("safetensors is required for loading experts from checkpoint files")
if not checkpoint_files:
raise FileNotFoundError("checkpoint_files is empty")
t0 = time.time()
weight_map = None
base_dir = os.path.dirname(checkpoint_files[0])
if sharded_metadata is not None:
weight_map = sharded_metadata.get("weight_map", None)
gate_name, up_name, down_name = moe_config.weight_names
experts_prefix = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}"
fused_gate_up_key = f"{experts_prefix}.gate_up_proj"
fused_down_key = f"{experts_prefix}.down_proj"
is_fused = weight_map is not None and fused_gate_up_key in weight_map
if is_fused:
keys = [fused_gate_up_key, fused_down_key]
else:
keys = []
for expert_idx in range(moe_config.expert_num):
base = f"{experts_prefix}.{expert_idx}"
keys.append(f"{base}.{gate_name}.weight")
keys.append(f"{base}.{gate_name}.weight_scale_inv")
keys.append(f"{base}.{up_name}.weight")
keys.append(f"{base}.{up_name}.weight_scale_inv")
keys.append(f"{base}.{down_name}.weight")
keys.append(f"{base}.{down_name}.weight_scale_inv")
keys_by_file: dict[str, list[str]] = {}
mapped_count = 0
unmapped_count = 0
for key in keys:
if weight_map is not None:
filename = weight_map.get(key)
if filename is None:
unmapped_count += 1
continue
mapped_count += 1
file_path = os.path.join(base_dir, filename)
else:
file_path = checkpoint_files[0]
keys_by_file.setdefault(file_path, []).append(key)
print(
f"[kt_moe] Layer {layer_idx}: key mapping done in {time.time()-t0:.1f}s — "
f"total_keys={len(keys)}, mapped={mapped_count}, unmapped={unmapped_count}, "
f"files_to_open={len(keys_by_file)}",
flush=True,
)
t1 = time.time()
tensor_map: dict[str, torch.Tensor] = {}
for file_idx, (file_path, file_keys) in enumerate(keys_by_file.items()):
with safe_open(file_path, framework="pt") as f:
available_keys = set(f.keys())
for key in file_keys:
if key in available_keys:
tensor_map[key] = f.get_tensor(key)
if file_idx == 0:
print(
f"[kt_moe] Layer {layer_idx}: first file loaded ({os.path.basename(file_path)}, "
f"{len(file_keys)} keys) in {time.time()-t1:.1f}s",
flush=True,
)
print(
f"[kt_moe] Layer {layer_idx}: all files loaded in {time.time()-t1:.1f}s — "
f"tensor_map has {len(tensor_map)} tensors",
flush=True,
)
t2 = time.time()
if is_fused:
gate_up = tensor_map.get(fused_gate_up_key)
down = tensor_map.get(fused_down_key)
if gate_up is None or down is None:
raise FileNotFoundError(f"Missing fused expert weights for layer {layer_idx}")
gate_up = gate_up.cpu().to(torch.bfloat16).contiguous()
I = gate_up.shape[1] // 2
gate_proj = gate_up[:, :I, :].contiguous()
up_proj = gate_up[:, I:, :].contiguous()
down_proj = down.cpu().to(torch.bfloat16).contiguous()
del gate_up
print(
f"[kt_moe] Layer {layer_idx}: fused expert format — "
f"split gate_up_proj [{gate_proj.shape}] + down [{down_proj.shape}]",
flush=True,
)
print(
f"[kt_moe] Layer {layer_idx}: done — dtype={gate_proj.dtype}, "
f"shape={gate_proj.shape}, dequant=0.0s, total={time.time()-t0:.1f}s",
flush=True,
)
return gate_proj, up_proj, down_proj
gate_weights = []
up_weights = []
down_weights = []
gate_scales = []
up_scales = []
down_scales = []
for expert_idx in range(moe_config.expert_num):
base = f"{experts_prefix}.{expert_idx}"
gate_key = f"{base}.{gate_name}.weight"
up_key = f"{base}.{up_name}.weight"
down_key = f"{base}.{down_name}.weight"
if gate_key not in tensor_map or up_key not in tensor_map or down_key not in tensor_map:
raise FileNotFoundError(f"Missing expert weights for layer {layer_idx}, expert {expert_idx}")
gate_weights.append(tensor_map[gate_key])
up_weights.append(tensor_map[up_key])
down_weights.append(tensor_map[down_key])
gate_scales.append(tensor_map.get(f"{base}.{gate_name}.weight_scale_inv"))
up_scales.append(tensor_map.get(f"{base}.{up_name}.weight_scale_inv"))
down_scales.append(tensor_map.get(f"{base}.{down_name}.weight_scale_inv"))
# Check if weights are FP8 and need dequantization
t2 = time.time()
is_fp8 = gate_weights[0].dtype == torch.float8_e4m3fn
if is_fp8:
if block_size is None:
block_size = (128, 128)
print(
f"[kt_moe] Layer {layer_idx}: FP8 expert weights detected, "
f"dequantizing with block_size={block_size} "
f"(has_scales={gate_scales[0] is not None})",
flush=True,
)
gate_proj = _dequant_fp8_experts(gate_weights, gate_scales, block_size)
up_proj = _dequant_fp8_experts(up_weights, up_scales, block_size)
down_proj = _dequant_fp8_experts(down_weights, down_scales, block_size)
else:
gate_proj = torch.stack(gate_weights, dim=0).cpu().to(torch.bfloat16).contiguous()
up_proj = torch.stack(up_weights, dim=0).cpu().to(torch.bfloat16).contiguous()
down_proj = torch.stack(down_weights, dim=0).cpu().to(torch.bfloat16).contiguous()
print(
f"[kt_moe] Layer {layer_idx}: done — dtype={gate_proj.dtype}, shape={gate_proj.shape}, "
f"dequant={time.time()-t2:.1f}s, total={time.time()-t0:.1f}s",
flush=True,
)
return gate_proj, up_proj, down_proj
def load_experts_from_kt_weight_path(
kt_weight_path: str,
layer_idx: int,
num_experts: int,
hidden_size: int,
intermediate_size: int,
) -> INT8ExpertWeights:
"""Load INT8 preprocessed expert weights from kt_weight_path for a specific layer."""
if not SAFETENSORS_AVAILABLE:
raise ImportError("safetensors is required for loading kt_weight_path")
index = _load_kt_weight_index(kt_weight_path)
numa_count = 0
test_key_prefix = f"blk.{layer_idx}.ffn_gate_exps.0.numa."
for key in index.keys():
if key.startswith(test_key_prefix) and key.endswith(".weight"):
numa_idx = int(key.split("numa.")[1].split(".")[0])
numa_count = max(numa_count, numa_idx + 1)
if numa_count == 0:
raise FileNotFoundError(
f"No weights found for layer {layer_idx} in {kt_weight_path}. "
f"Expected keys like 'blk.{layer_idx}.ffn_gate_exps.0.numa.0.weight'"
)
logger.info(
f"Loading INT8 weights for layer {layer_idx}: {num_experts} experts, {numa_count} NUMA partitions"
)
gate_weights_list = []
gate_scales_list = []
up_weights_list = []
up_scales_list = []
down_weights_list = []
down_scales_list = []
for expert_idx in range(num_experts):
gate_w_parts = []
gate_s_parts = []
for numa_idx in range(numa_count):
w_key = f"blk.{layer_idx}.ffn_gate_exps.{expert_idx}.numa.{numa_idx}.weight"
s_key = f"blk.{layer_idx}.ffn_gate_exps.{expert_idx}.numa.{numa_idx}.scale"
if w_key not in index:
raise FileNotFoundError(f"Weight key not found: {w_key}")
with safe_open(index[w_key], framework="pt") as f:
gate_w_parts.append(f.get_tensor(w_key))
gate_s_parts.append(f.get_tensor(s_key))
gate_w = torch.cat(gate_w_parts, dim=0)
gate_s = torch.cat(gate_s_parts, dim=0)
gate_w = gate_w.view(intermediate_size, hidden_size)
gate_weights_list.append(gate_w)
gate_scales_list.append(gate_s)
up_w_parts = []
up_s_parts = []
for numa_idx in range(numa_count):
w_key = f"blk.{layer_idx}.ffn_up_exps.{expert_idx}.numa.{numa_idx}.weight"
s_key = f"blk.{layer_idx}.ffn_up_exps.{expert_idx}.numa.{numa_idx}.scale"
if w_key not in index:
raise FileNotFoundError(f"Weight key not found: {w_key}")
with safe_open(index[w_key], framework="pt") as f:
up_w_parts.append(f.get_tensor(w_key))
up_s_parts.append(f.get_tensor(s_key))
up_w = torch.cat(up_w_parts, dim=0)
up_s = torch.cat(up_s_parts, dim=0)
up_w = up_w.view(intermediate_size, hidden_size)
up_weights_list.append(up_w)
up_scales_list.append(up_s)
down_w_parts = []
down_s_parts = []
for numa_idx in range(numa_count):
w_key = f"blk.{layer_idx}.ffn_down_exps.{expert_idx}.numa.{numa_idx}.weight"
s_key = f"blk.{layer_idx}.ffn_down_exps.{expert_idx}.numa.{numa_idx}.scale"
if w_key not in index:
raise FileNotFoundError(f"Weight key not found: {w_key}")
with safe_open(index[w_key], framework="pt") as f:
down_w_parts.append(f.get_tensor(w_key))
down_s_parts.append(f.get_tensor(s_key))
down_w = torch.cat(down_w_parts, dim=0)
down_s = torch.cat(down_s_parts, dim=0)
down_w = down_w.view(hidden_size, intermediate_size)
down_weights_list.append(down_w)
down_scales_list.append(down_s)
gate_proj = torch.stack(gate_weights_list, dim=0)
gate_scale = torch.stack(gate_scales_list, dim=0)
up_proj = torch.stack(up_weights_list, dim=0)
up_scale = torch.stack(up_scales_list, dim=0)
down_proj = torch.stack(down_weights_list, dim=0)
down_scale = torch.stack(down_scales_list, dim=0)
return INT8ExpertWeights(
gate_proj=gate_proj,
gate_scale=gate_scale,
up_proj=up_proj,
up_scale=up_scale,
down_proj=down_proj,
down_scale=down_scale,
)
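# --- Illustrative usage (hypothetical sizes, added for exposition) ---
# Loads the preprocessed INT8 expert weights for one layer. The dimensions below
# are placeholders; use the values from the model config. Per the view() calls
# above, gate/up come back as [num_experts, intermediate_size, hidden_size] and
# down as [num_experts, hidden_size, intermediate_size].
def _load_int8_layer_sketch(kt_weight_path: str) -> INT8ExpertWeights:
    return load_experts_from_kt_weight_path(
        kt_weight_path=kt_weight_path,
        layer_idx=0,
        num_experts=128,
        hidden_size=4096,
        intermediate_size=1536,
    )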


@@ -0,0 +1,588 @@
# Model wrapping entry points for SFT
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import gc
import importlib.util as _u
import logging
import os
from typing import Any
import torch
import torch.nn as nn
from .arch import (
KTAMXConfigError,
KTAMXNotAvailableError,
MOEArchConfig,
_get_layers_prefix,
_get_model_container_and_layers,
get_moe_arch_config,
get_moe_module,
)
from .layer import KTMoELayerWrapper
from .lora import LoRAExperts
from .weights import (
_clear_original_expert_weights,
extract_moe_weights,
load_experts_from_checkpoint_files,
)
logger = logging.getLogger(__name__)
KT_KERNEL_AVAILABLE = _u.find_spec("kt_kernel") is not None
if KT_KERNEL_AVAILABLE:
try:
from kt_kernel.experts import KTMoEWrapper
except Exception:
KTMoEWrapper = None
KT_KERNEL_AVAILABLE = False
else:
KTMoEWrapper = None
# =============================================================================
# Device-map builders
# =============================================================================
def _get_kt_config(kt_plugin: Any):
"""Extract KTConfig from a KTransformersPlugin or compatible object.
KTConfig field names use the kt_ prefix and match the dict keys in
HfTrainerKTConfig exactly, so no name-mapping is needed.
"""
from .config import KTConfig
if isinstance(kt_plugin, KTConfig):
return kt_plugin
kt_config = getattr(kt_plugin, "kt_config", None)
if kt_config is not None and isinstance(kt_config, KTConfig):
return kt_config
return KTConfig.from_object(kt_plugin)
def build_kt_device_map(config, kt_plugin, device: str = "cuda:0") -> dict[str, str | int]:
"""
Build device_map for KT model loading with hybrid GPU/CPU expert placement.
"""
moe_config = get_moe_arch_config(config)
layers_prefix = _get_layers_prefix(config)
num_layers = config.num_hidden_layers
num_experts = moe_config.expert_num
cfg = _get_kt_config(kt_plugin)
num_gpu_experts = getattr(cfg, "kt_num_gpu_experts", 0) or 0
device_map: dict[str, str | int] = {}
device_map["model.embed_tokens"] = device
device_map["model.norm"] = device
device_map["lm_head"] = device
for layer_idx in range(num_layers):
layer_prefix = f"{layers_prefix}.{layer_idx}"
device_map[layer_prefix] = device
moe_prefix = f"{layer_prefix}.{moe_config.moe_layer_attr}"
for expert_idx in range(num_experts):
expert_key = f"{moe_prefix}.{moe_config.experts_attr}.{expert_idx}"
if expert_idx < num_gpu_experts:
device_map[expert_key] = device
else:
device_map[expert_key] = "cpu"
logger.info(
f"Built KT device_map: {num_gpu_experts} GPU experts, {num_experts - num_gpu_experts} CPU experts"
)
return device_map
def build_kt_device_map_simplified(config, kt_plugin, device: str = "cuda:0") -> dict[str, str | int]:
"""
Simplified device_map builder: map full layers to GPU, override routed experts to CPU.
"""
moe_config = get_moe_arch_config(config)
layers_prefix = _get_layers_prefix(config)
num_layers = config.num_hidden_layers
cfg = _get_kt_config(kt_plugin)
num_gpu_experts = getattr(cfg, "kt_num_gpu_experts", 0) or 0
device_map: dict[str, str | int] = {}
device_map["model.embed_tokens"] = device
device_map["model.norm"] = device
device_map["lm_head"] = device
for layer_idx in range(num_layers):
layer_prefix = f"{layers_prefix}.{layer_idx}"
device_map[layer_prefix] = device
experts_prefix = f"{layer_prefix}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}"
if num_gpu_experts == 0:
device_map[experts_prefix] = "cpu"
else:
return build_kt_device_map(config, kt_plugin, device=device)
logger.info("Built simplified KT device_map: all layers on GPU, routed experts on CPU")
return device_map
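# --- Expected shape of the result (illustrative; module names are hypothetical) ---
# For a config whose decoder layers live under "model.layers" with the MoE block at
# ".mlp" and routed experts at ".mlp.experts", the simplified builder with
# kt_num_gpu_experts=0 yields a mapping like:
#   "model.embed_tokens"           -> "cuda:0"
#   "model.norm"                   -> "cuda:0"
#   "lm_head"                      -> "cuda:0"
#   "model.layers.<i>"             -> "cuda:0"   (every decoder layer)
#   "model.layers.<i>.mlp.experts" -> "cpu"      (routed experts only)
# With kt_num_gpu_experts > 0 it falls back to build_kt_device_map(), which places
# the first num_gpu_experts experts of each layer on the GPU.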
# =============================================================================
# MoE layer wrapping
# =============================================================================
def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KTMoELayerWrapper]:
"""
Replace model's MoE layers with KTMoEWrapper-based wrappers.
Loads expert weights into the C++ KT kernel. No LoRA initialization ---
LoRA is handled by PEFT and later adapted via kt_adapt_peft_lora().
Only rank 0 initializes KT kernel and loads weights.
"""
import torch.distributed as dist
if not KT_KERNEL_AVAILABLE:
raise KTAMXNotAvailableError("kt_kernel not found. Please install kt_kernel to enable KT MoE support.")
# Only rank 0 should initialize KT and load weights
is_rank_0 = True
if dist.is_initialized():
is_rank_0 = dist.get_rank() == 0
moe_config = get_moe_arch_config(model.config)
_text_cfg = getattr(model.config, "text_config", model.config)
hidden_size = _text_cfg.hidden_size
cfg = _get_kt_config(kt_plugin)
# Read lora_rank/lora_alpha for C++ wrapper initialization (buffer allocation only)
lora_rank = getattr(cfg, "kt_lora_rank", 1) or 1
lora_alpha = getattr(cfg, "kt_lora_alpha", 1.0) or 1.0
# Read LoRA Experts configuration
_raw_le = getattr(cfg, "kt_use_lora_experts", None)
use_lora_experts = bool(_raw_le) if _raw_le is not None else False
lora_expert_num = getattr(cfg, "kt_lora_expert_num", 2) or 2
lora_expert_intermediate_size = getattr(cfg, "kt_lora_expert_intermediate_size", 1024) or 1024
if is_rank_0:
logger.info(
f"LoRA Experts config: use_lora_experts={use_lora_experts}, "
f"num={lora_expert_num}, intermediate_size={lora_expert_intermediate_size}"
)
wrappers: list[KTMoELayerWrapper] = []
moe_layer_count = 0
kt_backend_map = {
"AMXBF16": "AMXBF16_SFT",
"AMXINT8": "AMXINT8_SFT",
"AMXINT4": "AMXINT4_SFT",
"AMXBF16_SkipLoRA": "AMXBF16_SFT_SkipLoRA",
"AMXINT8_SkipLoRA": "AMXINT8_SFT_SkipLoRA",
"AMXINT4_SkipLoRA": "AMXINT4_SFT_SkipLoRA",
}
# Build case-insensitive lookup to handle common typos like "SkipLora" vs "SkipLoRA"
_kt_backend_map_lower = {k.lower(): v for k, v in kt_backend_map.items()}
kt_backend = getattr(cfg, "kt_backend", None) or "AMXBF16"
kt_method = kt_backend_map.get(kt_backend) or _kt_backend_map_lower.get(kt_backend.lower(), "AMXBF16_SFT")
if kt_method != kt_backend_map.get(kt_backend):
logger.warning(
f"kt_backend '{kt_backend}' matched via case-insensitive lookup -> '{kt_method}'. "
f"Please use the exact name from: {list(kt_backend_map.keys())}"
)
if "SkipLoRA" in kt_method:
logger.info(f"Using SkipLoRA backend: {kt_method} (MoE LoRA gradients will be skipped)")
threadpool_count = getattr(cfg, "kt_threadpool_count", 1) if getattr(cfg, "kt_tp_enabled", False) else 1
kt_weight_path = getattr(cfg, "kt_weight_path", None)
use_kt_weight_path = kt_weight_path is not None
if use_kt_weight_path:
logger.info(f"Loading INT8 weights from kt_weight_path: {kt_weight_path}")
checkpoint_files = getattr(cfg, "kt_checkpoint_files", None)
sharded_metadata = getattr(cfg, "kt_sharded_metadata", None)
# When kt_expert_checkpoint_path is set, always resolve from it (overrides any existing
# checkpoint_files which may come from AttnOnlyBf16 and lack expert weights).
kt_expert_checkpoint_path = getattr(cfg, "kt_expert_checkpoint_path", None)
if kt_expert_checkpoint_path:
logger.info(f"Resolving expert checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}")
resolved_files, resolved_meta = _resolve_checkpoint_files(model_name_or_path=kt_expert_checkpoint_path)
if resolved_files and all(f.endswith(".safetensors") for f in resolved_files):
checkpoint_files = resolved_files
sharded_metadata = resolved_meta
cfg.kt_checkpoint_files = checkpoint_files
cfg.kt_sharded_metadata = sharded_metadata
logger.info(f"Resolved {len(checkpoint_files)} checkpoint files from kt_expert_checkpoint_path")
else:
logger.warning(f"Failed to resolve checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}")
use_checkpoint_files = bool(checkpoint_files) and not use_kt_weight_path
logger.debug(
f"Weight source: kt_weight_path={kt_weight_path!r}, "
f"kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}, "
f"checkpoint_files count={len(checkpoint_files) if checkpoint_files else 0}, "
f"use_kt_weight_path={use_kt_weight_path}, use_checkpoint_files={use_checkpoint_files}"
)
if use_checkpoint_files:
logger.info("Loading expert weights from checkpoint files (online conversion).")
elif use_kt_weight_path and bool(checkpoint_files):
logger.info("BF16 checkpoint files available for backward gradient computation.")
elif (not use_kt_weight_path) and bool(getattr(cfg, "kt_skip_expert_loading", False)):
# If HF expert weights were skipped during `from_pretrained`, we must source expert weights externally.
model_name_or_path = getattr(getattr(model, "config", None), "name_or_path", None)
if model_name_or_path:
resolved_files, resolved_meta = _resolve_checkpoint_files(model_name_or_path=model_name_or_path)
if resolved_files and all(f.endswith(".safetensors") for f in resolved_files):
checkpoint_files = resolved_files
sharded_metadata = resolved_meta
cfg.kt_checkpoint_files = checkpoint_files
cfg.kt_sharded_metadata = sharded_metadata
use_checkpoint_files = True
logger.info("KT skip_expert_loading enabled; using checkpoint files for online expert loading.")
if not use_checkpoint_files:
raise KTAMXConfigError(
"KT skip_expert_loading is enabled but no `kt_weight_path` was provided and no safetensors checkpoint "
"files could be resolved for on-the-fly expert loading."
)
import torch.distributed as _dist
_rank = _dist.get_rank() if _dist.is_initialized() else 0
model_container, layers = _get_model_container_and_layers(model, purpose="wrapping")
logger.info(f"Total layers={len(layers)}, is_rank_0={is_rank_0}")
from .arch import detect_fused_experts as _detect_fused
for layer_idx, layer in enumerate(layers):
moe_module = get_moe_module(layer, moe_config)
if moe_module is None:
continue
_layer_experts = getattr(moe_module, moe_config.experts_attr, None)
_layer_is_fused = _detect_fused(_layer_experts)
logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method}, fused={_layer_is_fused})")
# Only rank 0 loads weights and initializes KT kernel
gate_proj, up_proj, down_proj = None, None, None
wrapper = None
if is_rank_0:
# Get block_size from quantization_config if available (for FP8 dequant)
_quant_cfg = getattr(model.config, "quantization_config", None)
_block_size = None
if _quant_cfg is not None:
_block_size = getattr(_quant_cfg, "weight_block_size", None)
if use_kt_weight_path:
logger.debug(f"Layer {layer_idx}: forward + backward from kt_weight_path (.kt files)")
elif use_checkpoint_files:
layers_prefix = _get_layers_prefix(model.config)
gate_proj, up_proj, down_proj = load_experts_from_checkpoint_files(
checkpoint_files=checkpoint_files,
sharded_metadata=sharded_metadata,
layers_prefix=layers_prefix,
moe_config=moe_config,
layer_idx=layer_idx,
block_size=_block_size,
)
else:
gate_proj, up_proj, down_proj = extract_moe_weights(moe_module, moe_config)
gate_proj = gate_proj.cpu().to(torch.bfloat16).contiguous()
up_proj = up_proj.cpu().to(torch.bfloat16).contiguous()
down_proj = down_proj.cpu().to(torch.bfloat16).contiguous()
chunked_prefill_size = getattr(cfg, "kt_model_max_length", None)
if chunked_prefill_size is None:
chunked_prefill_size = getattr(model.config, "max_position_embeddings", 4096)
# Only rank 0 creates KTMoEWrapper and loads weights
if is_rank_0:
wrapper = KTMoEWrapper(
layer_idx=layer_idx,
num_experts=moe_config.expert_num,
num_experts_per_tok=moe_config.num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_config.intermediate_size,
gpu_experts_mask=None,
num_gpu_experts=0,
cpuinfer_threads=getattr(cfg, "kt_num_threads", 1),
threadpool_count=threadpool_count,
weight_path=kt_weight_path or "",
chunked_prefill_size=chunked_prefill_size,
method=kt_method,
mode="sft",
lora_rank=lora_rank,
lora_alpha=lora_alpha,
max_cache_depth=getattr(cfg, "kt_max_cache_depth", 2),
)
# Set share_backward_bb and share_cache_pool BEFORE load_weights (config is built during load)
wrapper.share_backward_bb = cfg.kt_share_backward_bb
wrapper.share_cache_pool = cfg.kt_share_cache_pool
physical_to_logical_map = torch.arange(moe_config.expert_num, dtype=torch.int64, device="cpu")
if use_kt_weight_path:
logger.debug(f"Layer {layer_idx}: calling wrapper.load_weights() (C++ direct .kt load)")
wrapper.load_weights(physical_to_logical_map)
else:
logger.debug(
f"Layer {layer_idx}: calling wrapper.load_weights_from_tensors() "
f"(BF16 tensor path, gate_proj shape={gate_proj.shape if gate_proj is not None else None})"
)
wrapper.load_weights_from_tensors(
gate_proj=gate_proj,
up_proj=up_proj,
down_proj=down_proj,
physical_to_logical_map_cpu=physical_to_logical_map,
)
wrapper.gate_proj = None
wrapper.up_proj = None
wrapper.down_proj = None
# Create LoRA Experts if enabled
lora_experts = None
if use_lora_experts:
lora_experts = LoRAExperts(
num_experts=lora_expert_num,
hidden_size=hidden_size,
intermediate_size=lora_expert_intermediate_size,
device="cuda",
dtype=torch.bfloat16,
)
layer_wrapper = KTMoELayerWrapper(
original_moe=moe_module,
wrapper=wrapper,
lora_params=None,
moe_config=moe_config,
hidden_size=hidden_size,
layer_idx=layer_idx,
lora_experts=lora_experts,
)
layer_wrapper._fused_experts = _layer_is_fused
layer_wrapper._lora_rank = lora_rank
setattr(layer, moe_config.moe_layer_attr, layer_wrapper)
# Base weights have been copied into the C++ kernel's internal BufferB format.
# Do not hold a Python-side reference --- it wastes ~1 GB/layer.
del gate_proj, up_proj, down_proj
wrappers.append(layer_wrapper)
moe_layer_count += 1
# Replace original expert weights with meta placeholders.
# Experts remain in the model tree (via wrapper.experts) so PEFT can discover them.
# Rank 0 already copied weights to C++ kernel via load_weights_from_tensors.
_clear_original_expert_weights(moe_module, moe_config)
logger.info(f"Wrapped {moe_layer_count} MoE layers with KTMoEWrapper")
# Link wrappers for async backward repack (higher layer triggers repack for lower)
for i in range(1, len(wrappers)):
if wrappers[i].wrapper is not None and wrappers[i - 1].wrapper is not None:
wrappers[i].wrapper._next_backward_wrapper = wrappers[i - 1].wrapper
if wrappers and wrappers[0].wrapper is not None:
wrappers[0].wrapper._next_backward_wrapper = None
gc.collect()
return wrappers
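# --- Illustrative usage (added for exposition) ---
# wrap_moe_layers_with_kt_wrapper() is normally reached through load_kt_model()
# below; calling it directly only makes sense once the HF model has been loaded
# with the KT config hooks active, e.g.:
#   wrappers = wrap_moe_layers_with_kt_wrapper(model, kt_plugin)
#   model._kt_wrappers = wrappers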
# =============================================================================
# Plugin builder
# =============================================================================
def _build_kt_plugin_from_args(model_args: Any, finetuning_args: Any | None = None):
"""
Build a KTransformersPlugin from model_args and optional finetuning_args.
The plugin class is imported inside this function to avoid a circular
dependency --- callers that need the plugin class should import it from the
appropriate dataclasses module.
"""
from .config import KTConfig
from accelerate.utils.dataclasses import KTransformersPlugin
kt_config = KTConfig(
kt_backend=getattr(model_args, "kt_backend", None),
kt_num_threads=getattr(model_args, "kt_num_threads", None),
kt_tp_enabled=getattr(model_args, "kt_tp_enabled", None),
kt_threadpool_count=getattr(model_args, "kt_threadpool_count", None),
kt_max_cache_depth=getattr(model_args, "kt_max_cache_depth", None),
kt_num_gpu_experts=getattr(model_args, "kt_num_gpu_experts", None),
kt_weight_path=getattr(model_args, "kt_weight_path", None),
kt_expert_checkpoint_path=getattr(model_args, "kt_expert_checkpoint_path", None),
kt_use_lora_experts=getattr(model_args, "kt_use_lora_experts", None),
kt_lora_expert_num=getattr(model_args, "kt_lora_expert_num", None),
kt_lora_expert_intermediate_size=getattr(model_args, "kt_lora_expert_intermediate_size", None),
kt_lora_rank=getattr(finetuning_args, "lora_rank", None) if finetuning_args else None,
kt_lora_alpha=getattr(finetuning_args, "lora_alpha", None) if finetuning_args else None,
kt_model_max_length=getattr(model_args, "model_max_length", None),
)
return KTransformersPlugin(enabled=True, kt_config=kt_config)
def get_kt_loading_kwargs(
config,
kt_plugin,
torch_dtype: torch.dtype | str | None = torch.bfloat16,
trust_remote_code: bool | None = None,
token: str | None = None,
) -> dict[str, Any]:
"""Get kwargs for AutoModel.from_pretrained() for KT loading."""
kwargs: dict[str, Any] = {
"config": config,
"torch_dtype": torch_dtype,
"device_map": "cpu",
"low_cpu_mem_usage": True,
}
if trust_remote_code is not None:
kwargs["trust_remote_code"] = trust_remote_code
if token is not None:
kwargs["token"] = token
return kwargs
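# --- Illustrative usage (hypothetical path, added for exposition) ---
#   kwargs = get_kt_loading_kwargs(config, kt_plugin)
#   model = AutoModelForCausalLM.from_pretrained("/path/to/model", **kwargs)
# load_kt_model() below performs this call itself, wrapped in
# set_kt_config()/unset_kt_config(); the snippet only shows where the kwargs end up.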
def _resolve_checkpoint_files(
model_name_or_path: str,
cache_dir: str | None = None,
revision: str | None = None,
token: str | None = None,
trust_remote_code: bool | None = None,
) -> tuple[list[str] | None, dict | None]:
"""Resolve HF checkpoint files. Depends on transformers internals."""
try:
from transformers.modeling_utils import _get_resolved_checkpoint_files
except Exception:
return None, None
try:
checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
pretrained_model_name_or_path=model_name_or_path,
subfolder="",
variant=None,
gguf_file=None,
from_tf=False,
from_flax=False,
use_safetensors=None,
cache_dir=cache_dir,
force_download=False,
proxies=None,
local_files_only=False,
token=token,
user_agent={"file_type": "model", "framework": "pytorch"},
revision=revision or "main",
commit_hash=None,
is_remote_code=bool(trust_remote_code),
transformers_explicit_filename=None,
)
except Exception:
return None, None
return checkpoint_files, sharded_metadata
def load_kt_model(
config,
model_args: Any | None = None,
finetuning_args: Any | None = None,
kt_plugin=None,
model_name_or_path: str | None = None,
trust_remote_code: bool | None = None,
token: str | None = None,
torch_dtype: torch.dtype | str | None = torch.bfloat16,
**kwargs,
) -> nn.Module:
"""Load model with KTMoEWrapper backend."""
from .arch import get_moe_arch_config, move_non_experts_to_gpu, get_expert_device, KTAMXNotAvailableError, KTAMXConfigError
if kt_plugin is None:
if model_args is None:
raise KTAMXConfigError("Either kt_plugin or model_args must be provided to load_kt_model().")
kt_plugin = _build_kt_plugin_from_args(model_args, finetuning_args)
if model_name_or_path is None and model_args is not None:
model_name_or_path = getattr(model_args, "model_name_or_path", None)
if model_name_or_path is None:
raise KTAMXConfigError("model_name_or_path is required to load_kt_model().")
if trust_remote_code is None and model_args is not None:
trust_remote_code = getattr(model_args, "trust_remote_code", None)
if token is None and model_args is not None:
token = getattr(model_args, "hf_hub_token", None)
cache_dir = getattr(model_args, "cache_dir", None) if model_args is not None else None
revision = getattr(model_args, "revision", None) if model_args is not None else None
_ = get_moe_arch_config(config)
logger.info("Loading model with KTMoEWrapper backend")
from transformers import AutoModelForCausalLM
from transformers.integrations.kt import set_kt_config, unset_kt_config
loading_kwargs = get_kt_loading_kwargs(
config, kt_plugin, torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code, token=token,
)
if model_args is not None:
for key in ("cache_dir", "revision"):
value = getattr(model_args, key, None)
if value is not None:
loading_kwargs[key] = value
loading_kwargs.update(kwargs)
cfg = _get_kt_config(kt_plugin)
if getattr(cfg, "kt_skip_expert_loading", None) is None:
checkpoint_files, sharded_metadata = _resolve_checkpoint_files(
model_name_or_path=model_name_or_path,
cache_dir=cache_dir, revision=revision,
token=token, trust_remote_code=trust_remote_code,
)
if checkpoint_files and all(f.endswith(".safetensors") for f in checkpoint_files):
if getattr(cfg, "kt_weight_path", None) is None:
cfg.kt_skip_expert_loading = True
else:
cfg.kt_skip_expert_loading = False
cfg.kt_checkpoint_files = checkpoint_files
cfg.kt_sharded_metadata = sharded_metadata
else:
cfg.kt_skip_expert_loading = False
set_kt_config(kt_plugin)
try:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **loading_kwargs)
finally:
unset_kt_config()
moe_config = get_moe_arch_config(config)
move_non_experts_to_gpu(model, moe_config, device="cuda:0")
existing_wrappers = getattr(model, "_kt_wrappers", None)
if existing_wrappers:
logger.info(f"MoE layers already wrapped ({len(existing_wrappers)} layers), skipping re-wrap")
wrappers = existing_wrappers
else:
wrappers = wrap_moe_layers_with_kt_wrapper(model, kt_plugin)
model._kt_wrappers = wrappers
model._kt_tp_enabled = bool(getattr(cfg, "kt_tp_enabled", False))
model._kt_use_lora_experts = bool(getattr(cfg, "kt_use_lora_experts", False))
logger.info("Model loaded with KTMoEWrapper backend successfully")
return model
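# --- Illustrative end-to-end usage (hypothetical arguments, added for exposition) ---
#   config = AutoConfig.from_pretrained("/path/to/model", trust_remote_code=True)
#   model = load_kt_model(
#       config,
#       model_args=model_args,            # carries kt_backend, kt_num_threads, ...
#       finetuning_args=finetuning_args,  # carries lora_rank / lora_alpha
#       model_name_or_path="/path/to/model",
#   )
# After this call model._kt_wrappers holds one KTMoELayerWrapper per MoE layer and
# model._kt_tp_enabled / model._kt_use_lora_experts mirror the KTConfig flags.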