mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-30 21:00:07 +00:00
feat(sft): AMX MoE SFT backend with LoRA support (#1936)
Some checks failed
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
Some checks failed
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
* feat(sft): AMX MoE SFT backend with LoRA support Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD: Core C++ implementation: - sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines) - moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA - amx/moe-sft-tp.hpp: AMX-specific TP implementation - avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM - amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization - worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure - ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants) Python sft/ submodule (kt_kernel.sft): - base.py: BaseSFTMoEWrapper with buffer management (template method pattern) - amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction) - autograd.py: KTMoEFunction (torch.autograd.Function for distributed training) - layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers) - arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection) - weights.py: Expert weight extraction and checkpoint loading - lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter) - wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map - config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough) - dist_utils.py: Distributed gather/scatter, checkpoint-phase detection Design decisions: - Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights - DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework interaction fields), all logic in kt_kernel.sft - Inference isolation: importing kt_kernel does not load sft/ submodule - Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally. * refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code - KTConfig fields all use kt_ prefix matching dict keys — eliminates _OLD_TO_NEW mapping and prefix-stripping in wrapper.py - Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing is on (via training_args.py), flows through to C++ cache allocation - Remove dead checkpoint detection code: in_ckpt_recompute, in_ckpt_first_forward vars (assigned but never read), fallback _is_in_checkpoint_first_forward() function, unused inspect import - Remove redundant env var fallbacks in wrapper.py for share_backward_bb and share_cache_pool (KTConfig.__post_init__ already handles env vars) - Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check Verified: Qwen3-235B 3-step training on sap4, loss matches baseline (1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * refactor(sft): share_backward_bb default True, share_cache_pool auto-derived - kt_share_backward_bb defaults to True (always saves memory) - kt_share_cache_pool no longer reads from env var; defaults False, auto-set to True by trainer_config_process when gradient checkpointing is enabled Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument, but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * feat(sft): support transformers v5 fused expert format Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters (gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert nn.Linear modules. PEFT cannot attach LoRA to these, so we create KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers for the optimizer, and pre-assigned .grad for C++ backward. - arch.py: detect_fused_experts() detection - weights.py: fused format extraction and weight clearing - wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank - lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA, get_kt_lora_params collects fused params, deduplicate wrapper finding - layer.py: handle v5 TopKRouter tuple output, remove dead code - autograd.py: sync_forward_sft/submit_forward_sft API rename Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * feat(sft): add Qwen3.5 MoE support + fused checkpoint loading - arch.py: add Qwen3_5Moe arch match, read config from text_config, _get_layers_prefix returns model.language_model.layers for Qwen3.5, _get_model_container_and_layers searches language_model attr - weights.py: load_experts_from_checkpoint_files detects fused format (gate_up_proj in weight_map) and splits into gate/up/down - wrapper.py: hidden_size fallback to text_config Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * [fix](sft): align Python API with C++ backend after v5 refactor - wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature) - layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward - autograd.py: rename sync_forward_sft to sync_forward The sft-v5 refactor (commits58d7eab,dd1da65) renamed Python-side method calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original method names. This caused AttributeError on Qwen3.5-35B and other models. * align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults - Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace, sft_timer namespace, ITT API, extended do_work_stealing_job API) - Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp, moe-sft-tp.hpp, avx_kernels.hpp) - Restore pin_memory=True in KExpertsCPUBuffer (inference path) - Restore fused tensor transpose logic in convert_cpu_weights.py (main layout) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * revert CMakeLists.txt to main: remove debug flags and cpptrace dep Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * clean up dev artifacts: remove SFT design docs, debug examples, bench scripts Remove files not needed in the merge: - docs/SFT+KTWrapper/ (6 Chinese design docs) - docs/sft_moe_amx/ (21 dev/debug docs) - 12 debug/test example scripts - 6 SFT-specific bench scripts and report Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
This commit is contained in:
parent
22e9915ec9
commit
9544a8960d
41 changed files with 22866 additions and 937 deletions
83
kt-kernel/python/sft/__init__.py
Normal file
83
kt-kernel/python/sft/__init__.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
# SFT (Supervised Fine-Tuning) submodule for kt-kernel
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
SFT training support for KT-Kernel MoE.
|
||||
|
||||
This submodule adds training capabilities (forward/backward, LoRA, autograd,
|
||||
distributed) on top of the inference-only kt_kernel base package.
|
||||
|
||||
Additional dependencies beyond base kt_kernel: torch.nn, torch.distributed, peft (optional).
|
||||
"""
|
||||
|
||||
from .config import KTConfig
|
||||
from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer
|
||||
from .amx import AMXSFTMoEWrapper
|
||||
from .arch import (
|
||||
MOEArchConfig, get_moe_arch_config, get_moe_module, move_non_experts_to_gpu, get_expert_device,
|
||||
KTAMXError, KTAMXNotAvailableError, KTAMXModelNotSupportedError, KTAMXConfigError,
|
||||
)
|
||||
from .autograd import KTMoEFunction
|
||||
from .layer import KTMoELayerWrapper
|
||||
from .weights import (
|
||||
extract_moe_weights,
|
||||
load_experts_from_checkpoint_files,
|
||||
load_experts_from_kt_weight_path,
|
||||
INT8ExpertWeights,
|
||||
)
|
||||
from .lora import (
|
||||
kt_adapt_peft_lora,
|
||||
get_kt_lora_params,
|
||||
update_kt_lora_pointers,
|
||||
sync_kt_lora_gradients,
|
||||
save_lora_experts_to_adapter,
|
||||
save_kt_moe_to_adapter,
|
||||
load_lora_experts_from_adapter,
|
||||
load_kt_moe_from_adapter,
|
||||
LoRAExpertMLP,
|
||||
LoRAExperts,
|
||||
)
|
||||
from .wrapper import (
|
||||
wrap_moe_layers_with_kt_wrapper,
|
||||
build_kt_device_map,
|
||||
build_kt_device_map_simplified,
|
||||
get_kt_loading_kwargs,
|
||||
load_kt_model,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"KTConfig",
|
||||
"BaseSFTMoEWrapper",
|
||||
"KExpertsSFTBuffer",
|
||||
"AMXSFTMoEWrapper",
|
||||
"MOEArchConfig",
|
||||
"get_moe_arch_config",
|
||||
"get_moe_module",
|
||||
"move_non_experts_to_gpu",
|
||||
"get_expert_device",
|
||||
"KTAMXError",
|
||||
"KTAMXNotAvailableError",
|
||||
"KTAMXModelNotSupportedError",
|
||||
"KTAMXConfigError",
|
||||
"KTMoEFunction",
|
||||
"KTMoELayerWrapper",
|
||||
"extract_moe_weights",
|
||||
"load_experts_from_checkpoint_files",
|
||||
"load_experts_from_kt_weight_path",
|
||||
"INT8ExpertWeights",
|
||||
"kt_adapt_peft_lora",
|
||||
"get_kt_lora_params",
|
||||
"update_kt_lora_pointers",
|
||||
"sync_kt_lora_gradients",
|
||||
"save_lora_experts_to_adapter",
|
||||
"save_kt_moe_to_adapter",
|
||||
"load_lora_experts_from_adapter",
|
||||
"load_kt_moe_from_adapter",
|
||||
"LoRAExpertMLP",
|
||||
"LoRAExperts",
|
||||
"wrap_moe_layers_with_kt_wrapper",
|
||||
"build_kt_device_map",
|
||||
"build_kt_device_map_simplified",
|
||||
"get_kt_loading_kwargs",
|
||||
"load_kt_model",
|
||||
]
|
||||
434
kt-kernel/python/sft/amx.py
Normal file
434
kt-kernel/python/sft/amx.py
Normal file
|
|
@ -0,0 +1,434 @@
|
|||
# AMX SFT MoE Wrapper implementation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
AMX-based SFT MoE Wrapper. Forward/backward buffer management is in base class;
|
||||
this file handles weight loading, LoRA init, and C++ task construction.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import os
|
||||
import glob as _glob
|
||||
import torch
|
||||
from typing import Optional, List
|
||||
|
||||
from kt_kernel_ext.moe import MOESFTConfig
|
||||
|
||||
from ..utils.loader import BF16SafeTensorLoader, SafeTensorLoader
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import (
|
||||
AMXBF16_SFT_MOE,
|
||||
AMXInt8_SFT_MOE,
|
||||
AMXInt4_SFT_MOE,
|
||||
AMXBF16_SFT_MOE_SkipLoRA,
|
||||
AMXInt8_SFT_MOE_SkipLoRA,
|
||||
AMXInt4_SFT_MOE_SkipLoRA,
|
||||
)
|
||||
|
||||
_HAS_AMX_SFT_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_AMX_SFT_SUPPORT = False
|
||||
AMXBF16_SFT_MOE = None
|
||||
AMXInt8_SFT_MOE = None
|
||||
AMXInt4_SFT_MOE = None
|
||||
AMXBF16_SFT_MOE_SkipLoRA = None
|
||||
AMXInt8_SFT_MOE_SkipLoRA = None
|
||||
AMXInt4_SFT_MOE_SkipLoRA = None
|
||||
|
||||
from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer
|
||||
|
||||
|
||||
# Mapping from method string to C++ SFT MOE class
|
||||
_SFT_METHOD_TO_CLASS = {
|
||||
"AMXBF16_SFT": AMXBF16_SFT_MOE,
|
||||
"AMXINT8_SFT": AMXInt8_SFT_MOE,
|
||||
"AMXINT4_SFT": AMXInt4_SFT_MOE,
|
||||
"AMXBF16_SFT_SkipLoRA": AMXBF16_SFT_MOE_SkipLoRA,
|
||||
"AMXINT8_SFT_SkipLoRA": AMXInt8_SFT_MOE_SkipLoRA,
|
||||
"AMXINT4_SFT_SkipLoRA": AMXInt4_SFT_MOE_SkipLoRA,
|
||||
}
|
||||
|
||||
|
||||
class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
"""
|
||||
AMX-based SFT MoE wrapper.
|
||||
|
||||
Supports BF16, INT8, INT4, and SkipLoRA variants.
|
||||
Forward/backward buffer management is in BaseSFTMoEWrapper;
|
||||
this class implements weight loading and C++ task construction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_idx: int,
|
||||
num_experts: int,
|
||||
num_experts_per_tok: int,
|
||||
hidden_size: int,
|
||||
moe_intermediate_size: int,
|
||||
num_gpu_experts: int,
|
||||
cpuinfer_threads: int,
|
||||
threadpool_count: int,
|
||||
weight_path: str,
|
||||
chunked_prefill_size: int,
|
||||
lora_rank: int = 16,
|
||||
lora_alpha: float = 32.0,
|
||||
max_cache_depth: int = 1,
|
||||
method: str = "AMXBF16_SFT",
|
||||
group_size: int = 128,
|
||||
zero_point: bool = True,
|
||||
):
|
||||
if not _HAS_AMX_SFT_SUPPORT:
|
||||
raise RuntimeError(
|
||||
"AMX SFT backend not available. kt_kernel_ext was not compiled with AMX SFT support.\n"
|
||||
"Please recompile with AMX SFT enabled."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
layer_idx=layer_idx,
|
||||
num_experts=num_experts,
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
num_gpu_experts=num_gpu_experts,
|
||||
cpuinfer_threads=cpuinfer_threads,
|
||||
threadpool_count=threadpool_count,
|
||||
weight_path=weight_path,
|
||||
chunked_prefill_size=chunked_prefill_size,
|
||||
lora_rank=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
max_cache_depth=max_cache_depth,
|
||||
)
|
||||
|
||||
self.method = method
|
||||
self._is_skip_lora = "SkipLoRA" in method
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
|
||||
if method not in _SFT_METHOD_TO_CLASS:
|
||||
raise ValueError(f"Unknown SFT method: {method}. Supported: {list(_SFT_METHOD_TO_CLASS.keys())}")
|
||||
|
||||
moe_class = _SFT_METHOD_TO_CLASS[method]
|
||||
if moe_class is None:
|
||||
raise RuntimeError(f"AMX SFT method '{method}' not available in current build.")
|
||||
|
||||
self.gate_proj: Optional[torch.Tensor] = None
|
||||
self.up_proj: Optional[torch.Tensor] = None
|
||||
self.down_proj: Optional[torch.Tensor] = None
|
||||
|
||||
self._moe_class = moe_class
|
||||
|
||||
# ========== Template method: C++ task construction ==========
|
||||
|
||||
def _make_forward_task(self, buffer: KExpertsSFTBuffer, save_for_backward: bool):
|
||||
return self.moe.forward_sft_task(
|
||||
buffer.bsz_tensor.data_ptr(),
|
||||
self.num_experts_per_tok,
|
||||
buffer.expert_ids_cpu.data_ptr(),
|
||||
buffer.weights_cpu.data_ptr(),
|
||||
buffer.input_cpu.data_ptr(),
|
||||
buffer.output_cpu.data_ptr(),
|
||||
save_for_backward,
|
||||
)
|
||||
|
||||
def _make_backward_task(self, buffer: KExpertsSFTBuffer):
|
||||
if self._is_skip_lora:
|
||||
return self.moe.backward_task(
|
||||
buffer.grad_output_cpu.data_ptr(),
|
||||
buffer.grad_input_cpu.data_ptr(),
|
||||
0, 0, 0, 0, 0, 0,
|
||||
buffer.grad_weights.data_ptr(),
|
||||
)
|
||||
return self.moe.backward_task(
|
||||
buffer.grad_output_cpu.data_ptr(),
|
||||
buffer.grad_input_cpu.data_ptr(),
|
||||
self.grad_gate_lora_a.data_ptr(),
|
||||
self.grad_gate_lora_b.data_ptr(),
|
||||
self.grad_up_lora_a.data_ptr(),
|
||||
self.grad_up_lora_b.data_ptr(),
|
||||
self.grad_down_lora_a.data_ptr(),
|
||||
self.grad_down_lora_b.data_ptr(),
|
||||
buffer.grad_weights.data_ptr(),
|
||||
)
|
||||
|
||||
# ========== Weight loading ==========
|
||||
|
||||
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None:
|
||||
if self._weights_loaded:
|
||||
return
|
||||
|
||||
if self.gate_proj is None and not getattr(self, "_use_projs_path", False):
|
||||
self._load_base_weights_from_file()
|
||||
|
||||
config = MOESFTConfig()
|
||||
config.expert_num = self.num_experts
|
||||
config.num_experts_per_tok = self.num_experts_per_tok
|
||||
config.hidden_size = self.hidden_size
|
||||
config.intermediate_size = self.moe_intermediate_size
|
||||
config.lora_rank = self.lora_rank
|
||||
config.lora_alpha = self.lora_alpha
|
||||
config.max_cache_depth = self.max_cache_depth
|
||||
config.max_len = self.chunked_prefill_size
|
||||
config.layer_idx = self.layer_idx
|
||||
config.share_backward_bb = getattr(self, "share_backward_bb", False)
|
||||
config.share_cache_pool = getattr(self, "share_cache_pool", False)
|
||||
|
||||
if getattr(self, "_use_kt_direct_load", False):
|
||||
config.load = True
|
||||
config.path = self.weight_path
|
||||
elif getattr(self, "_use_projs_path", False):
|
||||
config.gate_projs = self._gate_projs_ptrs
|
||||
config.up_projs = self._up_projs_ptrs
|
||||
config.down_projs = self._down_projs_ptrs
|
||||
config.gate_scales = self._gate_scale_ptrs
|
||||
config.up_scales = self._up_scale_ptrs
|
||||
config.down_scales = self._down_scale_ptrs
|
||||
if getattr(self, "_bf16_gate_proj", None) is not None:
|
||||
config.gate_proj = self._bf16_gate_proj.data_ptr()
|
||||
config.up_proj = self._bf16_up_proj.data_ptr()
|
||||
config.down_proj = self._bf16_down_proj.data_ptr()
|
||||
if getattr(self, "_has_bwd_projs", False):
|
||||
config.gate_bwd_projs = self._gate_bwd_projs_ptrs
|
||||
config.up_bwd_projs = self._up_bwd_projs_ptrs
|
||||
config.down_bwd_projs = self._down_bwd_projs_ptrs
|
||||
config.gate_bwd_scales = self._gate_bwd_scale_ptrs
|
||||
config.up_bwd_scales = self._up_bwd_scale_ptrs
|
||||
config.down_bwd_scales = self._down_bwd_scale_ptrs
|
||||
else:
|
||||
config.gate_proj = self.gate_proj.data_ptr()
|
||||
config.up_proj = self.up_proj.data_ptr()
|
||||
config.down_proj = self.down_proj.data_ptr()
|
||||
|
||||
if self._lora_initialized:
|
||||
config.gate_lora_a = self.gate_lora_a.data_ptr()
|
||||
config.gate_lora_b = self.gate_lora_b.data_ptr()
|
||||
config.up_lora_a = self.up_lora_a.data_ptr()
|
||||
config.up_lora_b = self.up_lora_b.data_ptr()
|
||||
config.down_lora_a = self.down_lora_a.data_ptr()
|
||||
config.down_lora_b = self.down_lora_b.data_ptr()
|
||||
|
||||
config.pool = self.cpu_infer.backend_
|
||||
|
||||
if self.method in ("AMXINT4_KGroup_SFT", "AMXINT4_1KGroup_SFT"):
|
||||
config.quant_config.group_size = self.group_size
|
||||
config.quant_config.zero_point = self.zero_point
|
||||
|
||||
self.moe = self._moe_class(config)
|
||||
|
||||
self.cpu_infer.submit(self.moe.load_weights_task())
|
||||
self.cpu_infer.sync()
|
||||
|
||||
self.cpu_infer.submit(self.moe.warm_up_task())
|
||||
self.cpu_infer.sync()
|
||||
|
||||
# Release Python-side weight tensors (C++ copied them)
|
||||
self.gate_proj = None
|
||||
self.up_proj = None
|
||||
self.down_proj = None
|
||||
|
||||
if getattr(self, "_bf16_gate_proj", None) is not None:
|
||||
self._bf16_gate_proj = None
|
||||
self._bf16_up_proj = None
|
||||
self._bf16_down_proj = None
|
||||
|
||||
if getattr(self, "_use_projs_path", False):
|
||||
for attr in [
|
||||
"_gate_weights_per_numa", "_up_weights_per_numa", "_down_weights_per_numa",
|
||||
"_gate_scales_per_numa", "_up_scales_per_numa", "_down_scales_per_numa",
|
||||
"_gate_projs_ptrs", "_up_projs_ptrs", "_down_projs_ptrs",
|
||||
"_gate_scale_ptrs", "_up_scale_ptrs", "_down_scale_ptrs",
|
||||
]:
|
||||
setattr(self, attr, None)
|
||||
if getattr(self, "_has_bwd_projs", False):
|
||||
for attr in [
|
||||
"_gate_bwd_weights_per_numa", "_up_bwd_weights_per_numa", "_down_bwd_weights_per_numa",
|
||||
"_gate_bwd_scales_per_numa", "_up_bwd_scales_per_numa", "_down_bwd_scales_per_numa",
|
||||
"_gate_bwd_projs_ptrs", "_up_bwd_projs_ptrs", "_down_bwd_projs_ptrs",
|
||||
"_gate_bwd_scale_ptrs", "_up_bwd_scale_ptrs", "_down_bwd_scale_ptrs",
|
||||
]:
|
||||
setattr(self, attr, None)
|
||||
|
||||
self._weights_loaded = True
|
||||
|
||||
def load_weights_from_tensors(
|
||||
self,
|
||||
gate_proj: torch.Tensor,
|
||||
up_proj: torch.Tensor,
|
||||
down_proj: torch.Tensor,
|
||||
physical_to_logical_map_cpu: torch.Tensor,
|
||||
) -> None:
|
||||
self.gate_proj = gate_proj.contiguous()
|
||||
self.up_proj = up_proj.contiguous()
|
||||
self.down_proj = down_proj.contiguous()
|
||||
self.load_weights(physical_to_logical_map_cpu)
|
||||
del gate_proj, up_proj, down_proj
|
||||
|
||||
def _load_base_weights_from_file(self) -> None:
|
||||
if not hasattr(self, "weight_path") or self.weight_path is None:
|
||||
raise RuntimeError(
|
||||
"weight_path not set. Cannot load weights from file. "
|
||||
"Either set weight_path or call load_weights_from_tensors() instead."
|
||||
)
|
||||
|
||||
kt_layer_dir = os.path.join(self.weight_path, f"_layer_{self.layer_idx}")
|
||||
if os.path.isdir(kt_layer_dir):
|
||||
kt_files = _glob.glob(os.path.join(kt_layer_dir, "_numa_0", "*.kt"))
|
||||
if kt_files:
|
||||
self._use_kt_direct_load = True
|
||||
return
|
||||
|
||||
if "BF16" in self.method:
|
||||
loader = BF16SafeTensorLoader(self.weight_path)
|
||||
base_key = f"model.layers.{self.layer_idx}"
|
||||
else:
|
||||
loader = SafeTensorLoader(self.weight_path)
|
||||
base_key = f"blk.{self.layer_idx}"
|
||||
|
||||
experts_data = loader.load_experts(base_key, device="cpu")
|
||||
|
||||
gate_weights: List[torch.Tensor] = experts_data["gate"]
|
||||
up_weights: List[torch.Tensor] = experts_data["up"]
|
||||
down_weights: List[torch.Tensor] = experts_data["down"]
|
||||
|
||||
if "BF16" in self.method:
|
||||
self.gate_proj = torch.stack(gate_weights, dim=0).contiguous()
|
||||
self.up_proj = torch.stack(up_weights, dim=0).contiguous()
|
||||
self.down_proj = torch.stack(down_weights, dim=0).contiguous()
|
||||
else:
|
||||
def _make_ptrs(arrays_per_numa):
|
||||
return [
|
||||
[
|
||||
ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
|
||||
for et in numa_array
|
||||
]
|
||||
for numa_array in arrays_per_numa
|
||||
]
|
||||
|
||||
self._gate_weights_per_numa = gate_weights
|
||||
self._up_weights_per_numa = up_weights
|
||||
self._down_weights_per_numa = down_weights
|
||||
self._gate_scales_per_numa = experts_data["gate_scale"]
|
||||
self._up_scales_per_numa = experts_data["up_scale"]
|
||||
self._down_scales_per_numa = experts_data["down_scale"]
|
||||
|
||||
self._gate_projs_ptrs = _make_ptrs(gate_weights)
|
||||
self._up_projs_ptrs = _make_ptrs(up_weights)
|
||||
self._down_projs_ptrs = _make_ptrs(down_weights)
|
||||
self._gate_scale_ptrs = _make_ptrs(experts_data["gate_scale"])
|
||||
self._up_scale_ptrs = _make_ptrs(experts_data["up_scale"])
|
||||
self._down_scale_ptrs = _make_ptrs(experts_data["down_scale"])
|
||||
|
||||
if "gate_bwd" in experts_data:
|
||||
self._gate_bwd_weights_per_numa = experts_data["gate_bwd"]
|
||||
self._up_bwd_weights_per_numa = experts_data["up_bwd"]
|
||||
self._down_bwd_weights_per_numa = experts_data["down_bwd"]
|
||||
self._gate_bwd_scales_per_numa = experts_data["gate_bwd_scale"]
|
||||
self._up_bwd_scales_per_numa = experts_data["up_bwd_scale"]
|
||||
self._down_bwd_scales_per_numa = experts_data["down_bwd_scale"]
|
||||
|
||||
self._gate_bwd_projs_ptrs = _make_ptrs(experts_data["gate_bwd"])
|
||||
self._up_bwd_projs_ptrs = _make_ptrs(experts_data["up_bwd"])
|
||||
self._down_bwd_projs_ptrs = _make_ptrs(experts_data["down_bwd"])
|
||||
self._gate_bwd_scale_ptrs = _make_ptrs(experts_data["gate_bwd_scale"])
|
||||
self._up_bwd_scale_ptrs = _make_ptrs(experts_data["up_bwd_scale"])
|
||||
self._down_bwd_scale_ptrs = _make_ptrs(experts_data["down_bwd_scale"])
|
||||
self._has_bwd_projs = True
|
||||
else:
|
||||
self._has_bwd_projs = False
|
||||
|
||||
self.gate_proj = None
|
||||
self.up_proj = None
|
||||
self.down_proj = None
|
||||
self._use_projs_path = True
|
||||
|
||||
loader.close_all_handles()
|
||||
|
||||
# ========== LoRA ==========
|
||||
|
||||
def init_lora_weights(
|
||||
self,
|
||||
gate_lora_a: torch.Tensor, gate_lora_b: torch.Tensor,
|
||||
up_lora_a: torch.Tensor, up_lora_b: torch.Tensor,
|
||||
down_lora_a: torch.Tensor, down_lora_b: torch.Tensor,
|
||||
grad_gate_lora_a: torch.Tensor, grad_gate_lora_b: torch.Tensor,
|
||||
grad_up_lora_a: torch.Tensor, grad_up_lora_b: torch.Tensor,
|
||||
grad_down_lora_a: torch.Tensor, grad_down_lora_b: torch.Tensor,
|
||||
) -> None:
|
||||
expected_shapes = {
|
||||
"gate_lora_a": (self.num_experts, self.lora_rank, self.hidden_size),
|
||||
"gate_lora_b": (self.num_experts, self.moe_intermediate_size, self.lora_rank),
|
||||
"up_lora_a": (self.num_experts, self.lora_rank, self.hidden_size),
|
||||
"up_lora_b": (self.num_experts, self.moe_intermediate_size, self.lora_rank),
|
||||
"down_lora_a": (self.num_experts, self.lora_rank, self.moe_intermediate_size),
|
||||
"down_lora_b": (self.num_experts, self.hidden_size, self.lora_rank),
|
||||
}
|
||||
provided = {
|
||||
"gate_lora_a": gate_lora_a, "gate_lora_b": gate_lora_b,
|
||||
"up_lora_a": up_lora_a, "up_lora_b": up_lora_b,
|
||||
"down_lora_a": down_lora_a, "down_lora_b": down_lora_b,
|
||||
}
|
||||
for name, tensor in provided.items():
|
||||
expected = expected_shapes[name]
|
||||
if tensor.shape != expected:
|
||||
raise ValueError(f"{name} shape mismatch: expected {expected}, got {tuple(tensor.shape)}")
|
||||
|
||||
self.gate_lora_a = gate_lora_a.contiguous()
|
||||
self.gate_lora_b = gate_lora_b.contiguous()
|
||||
self.up_lora_a = up_lora_a.contiguous()
|
||||
self.up_lora_b = up_lora_b.contiguous()
|
||||
self.down_lora_a = down_lora_a.contiguous()
|
||||
self.down_lora_b = down_lora_b.contiguous()
|
||||
|
||||
self.grad_gate_lora_a = grad_gate_lora_a.contiguous()
|
||||
self.grad_gate_lora_b = grad_gate_lora_b.contiguous()
|
||||
self.grad_up_lora_a = grad_up_lora_a.contiguous()
|
||||
self.grad_up_lora_b = grad_up_lora_b.contiguous()
|
||||
self.grad_down_lora_a = grad_down_lora_a.contiguous()
|
||||
self.grad_down_lora_b = grad_down_lora_b.contiguous()
|
||||
|
||||
self._lora_initialized = True
|
||||
|
||||
if self._weights_loaded and self.moe is not None:
|
||||
self.update_lora_weights()
|
||||
|
||||
def update_lora_weights(self) -> None:
|
||||
if not self._weights_loaded:
|
||||
raise RuntimeError("Weights not loaded. Call load_weights() first.")
|
||||
if self._is_skip_lora:
|
||||
return
|
||||
if not self._lora_initialized:
|
||||
raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.")
|
||||
|
||||
self.cpu_infer.submit(
|
||||
self.moe.update_lora_weights_task(
|
||||
self.gate_lora_a.data_ptr(),
|
||||
self.gate_lora_b.data_ptr(),
|
||||
self.up_lora_a.data_ptr(),
|
||||
self.up_lora_b.data_ptr(),
|
||||
self.down_lora_a.data_ptr(),
|
||||
self.down_lora_b.data_ptr(),
|
||||
)
|
||||
)
|
||||
self.cpu_infer.sync()
|
||||
|
||||
def save_backward_weights_from_tensors(
|
||||
self,
|
||||
gate_proj: torch.Tensor,
|
||||
up_proj: torch.Tensor,
|
||||
down_proj: torch.Tensor,
|
||||
physical_to_logical_map: torch.Tensor,
|
||||
output_path: str,
|
||||
) -> None:
|
||||
if not self._weights_loaded:
|
||||
raise RuntimeError("Weights not loaded. Call load_weights() first.")
|
||||
gate_proj = gate_proj.contiguous()
|
||||
up_proj = up_proj.contiguous()
|
||||
down_proj = down_proj.contiguous()
|
||||
self.moe.prepare_and_save_bwd(
|
||||
gate_proj.data_ptr(),
|
||||
up_proj.data_ptr(),
|
||||
down_proj.data_ptr(),
|
||||
output_path,
|
||||
)
|
||||
282
kt-kernel/python/sft/arch.py
Normal file
282
kt-kernel/python/sft/arch.py
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
# MoE architecture configuration and model utilities
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
MoE architecture detection and model navigation utilities.
|
||||
|
||||
This is a leaf module — no imports from other sft/ submodules.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exceptions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class KTAMXError(Exception):
|
||||
"""Base exception for KT AMX errors."""
|
||||
|
||||
|
||||
class KTAMXNotAvailableError(KTAMXError):
|
||||
"""kt_kernel not installed or AMX not supported."""
|
||||
|
||||
|
||||
class KTAMXModelNotSupportedError(KTAMXError):
|
||||
"""Model architecture not supported."""
|
||||
|
||||
|
||||
class KTAMXConfigError(KTAMXError):
|
||||
"""Configuration error."""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MoE Configuration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class MOEArchConfig:
|
||||
"""MoE architecture configuration for different model types."""
|
||||
|
||||
moe_layer_attr: str
|
||||
router_attr: str
|
||||
experts_attr: str
|
||||
weight_names: tuple[str, str, str]
|
||||
expert_num: int
|
||||
intermediate_size: int
|
||||
num_experts_per_tok: int
|
||||
has_shared_experts: bool = False
|
||||
router_type: str = "linear"
|
||||
|
||||
|
||||
def get_moe_arch_config(config) -> MOEArchConfig:
|
||||
"""
|
||||
Get MoE architecture configuration based on model type.
|
||||
|
||||
Args:
|
||||
config: HuggingFace model configuration
|
||||
|
||||
Returns:
|
||||
MOEArchConfig for the model
|
||||
|
||||
Raises:
|
||||
KTAMXModelNotSupportedError: If model architecture is not supported
|
||||
"""
|
||||
arch = config.architectures[0] if getattr(config, "architectures", None) else ""
|
||||
|
||||
if "DeepseekV2" in arch:
|
||||
return MOEArchConfig(
|
||||
moe_layer_attr="mlp",
|
||||
router_attr="gate",
|
||||
experts_attr="experts",
|
||||
weight_names=("gate_proj", "up_proj", "down_proj"),
|
||||
expert_num=config.n_routed_experts,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
num_experts_per_tok=config.num_experts_per_tok,
|
||||
has_shared_experts=getattr(config, "n_shared_experts", 0) > 0,
|
||||
router_type="deepseek_gate",
|
||||
)
|
||||
if "DeepseekV3" in arch:
|
||||
return MOEArchConfig(
|
||||
moe_layer_attr="mlp",
|
||||
router_attr="gate",
|
||||
experts_attr="experts",
|
||||
weight_names=("gate_proj", "up_proj", "down_proj"),
|
||||
expert_num=config.n_routed_experts,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
num_experts_per_tok=config.num_experts_per_tok,
|
||||
has_shared_experts=getattr(config, "n_shared_experts", 0) > 0,
|
||||
router_type="deepseek_gate",
|
||||
)
|
||||
if "Qwen2Moe" in arch or "Qwen3Moe" in arch or "Qwen3_5Moe" in arch:
|
||||
cfg = getattr(config, "text_config", config)
|
||||
return MOEArchConfig(
|
||||
moe_layer_attr="mlp",
|
||||
router_attr="gate",
|
||||
experts_attr="experts",
|
||||
weight_names=("gate_proj", "up_proj", "down_proj"),
|
||||
expert_num=cfg.num_experts,
|
||||
intermediate_size=cfg.moe_intermediate_size,
|
||||
num_experts_per_tok=cfg.num_experts_per_tok,
|
||||
has_shared_experts=getattr(cfg, "shared_expert_intermediate_size", 0) > 0,
|
||||
)
|
||||
if "Mixtral" in arch:
|
||||
return MOEArchConfig(
|
||||
moe_layer_attr="block_sparse_moe",
|
||||
router_attr="gate",
|
||||
experts_attr="experts",
|
||||
weight_names=("w1", "w3", "w2"),
|
||||
expert_num=config.num_local_experts,
|
||||
intermediate_size=config.intermediate_size,
|
||||
num_experts_per_tok=config.num_experts_per_tok,
|
||||
has_shared_experts=False,
|
||||
)
|
||||
|
||||
raise KTAMXModelNotSupportedError(
|
||||
f"Model architecture {arch} not supported for KT AMX. "
|
||||
"Supported architectures: DeepseekV2, DeepseekV3, Qwen2Moe, Qwen3Moe, Qwen3_5Moe, Mixtral"
|
||||
)
|
||||
|
||||
|
||||
def get_moe_module(layer: nn.Module, moe_config: MOEArchConfig) -> nn.Module | None:
|
||||
"""Get MoE module from transformer layer."""
|
||||
moe_module = getattr(layer, moe_config.moe_layer_attr, None)
|
||||
if moe_module is None:
|
||||
return None
|
||||
if not hasattr(moe_module, moe_config.experts_attr):
|
||||
return None
|
||||
return moe_module
|
||||
|
||||
|
||||
def detect_fused_experts(experts: nn.Module) -> bool:
|
||||
"""Detect if experts module uses the transformers v5 fused format.
|
||||
|
||||
Fused format: a single Module with ``gate_up_proj`` [E, 2I, H] and
|
||||
``down_proj`` [E, H, I] 3-D tensors instead of a ModuleList of Linear experts.
|
||||
"""
|
||||
if experts is None:
|
||||
return False
|
||||
gate_up = getattr(experts, "gate_up_proj", None)
|
||||
down = getattr(experts, "down_proj", None)
|
||||
if isinstance(gate_up, torch.Tensor) and isinstance(down, torch.Tensor):
|
||||
return gate_up.dim() == 3 and down.dim() == 3
|
||||
return False
|
||||
|
||||
|
||||
def _get_layers_prefix(config) -> str:
|
||||
arch = config.architectures[0] if getattr(config, "architectures", None) else ""
|
||||
if "Qwen3_5Moe" in arch:
|
||||
return "model.language_model.layers"
|
||||
return "model.layers"
|
||||
|
||||
|
||||
def _get_model_container_and_layers(model: nn.Module, *, purpose: str) -> tuple[nn.Module, any]:
|
||||
"""
|
||||
Resolve the transformer layer container for KT integration.
|
||||
|
||||
KT expects the transformer block stack to be accessible as `<container>.layers`.
|
||||
Handles PEFT PeftModel, TRL value-head models, DDP wrappers.
|
||||
"""
|
||||
to_visit: list[nn.Module] = [model]
|
||||
visited: set[int] = set()
|
||||
visited_types: list[str] = []
|
||||
|
||||
while to_visit:
|
||||
current = to_visit.pop(0)
|
||||
if id(current) in visited:
|
||||
continue
|
||||
visited.add(id(current))
|
||||
visited_types.append(type(current).__name__)
|
||||
|
||||
layers = getattr(current, "layers", None)
|
||||
if layers is not None and isinstance(layers, (list, tuple, nn.ModuleList)):
|
||||
return current, layers
|
||||
|
||||
for attr in ("model", "base_model", "pretrained_model", "module", "language_model"):
|
||||
child = getattr(current, attr, None)
|
||||
if isinstance(child, nn.Module) and child is not current:
|
||||
to_visit.append(child)
|
||||
|
||||
get_base_model = getattr(current, "get_base_model", None)
|
||||
if callable(get_base_model):
|
||||
try:
|
||||
base = get_base_model()
|
||||
except Exception:
|
||||
base = None
|
||||
if isinstance(base, nn.Module) and base is not current:
|
||||
to_visit.append(base)
|
||||
|
||||
visited_preview = ", ".join(visited_types[:6])
|
||||
if len(visited_types) > 6:
|
||||
visited_preview += ", ..."
|
||||
|
||||
raise KTAMXConfigError(
|
||||
f"Model does not expose a .model.layers or .layers attribute for KT {purpose}. "
|
||||
"Tried unwrapping via model/base_model/pretrained_model/module/get_base_model; "
|
||||
f"visited: {visited_preview}"
|
||||
)
|
||||
|
||||
|
||||
def move_non_experts_to_gpu(
|
||||
model: nn.Module,
|
||||
moe_config: MOEArchConfig | None = None,
|
||||
device: str = "cuda:0",
|
||||
) -> None:
|
||||
"""Move non-expert parameters to GPU after loading (experts stay on CPU)."""
|
||||
if moe_config is None:
|
||||
config = getattr(model, "config", None)
|
||||
if config is None:
|
||||
raise KTAMXConfigError("Model config is required to infer MoE architecture.")
|
||||
moe_config = get_moe_arch_config(config)
|
||||
|
||||
container, layers = _get_model_container_and_layers(model, purpose="placement")
|
||||
|
||||
if hasattr(container, "embed_tokens"):
|
||||
container.embed_tokens.to(device)
|
||||
if hasattr(container, "norm"):
|
||||
container.norm.to(device)
|
||||
if hasattr(model, "lm_head"):
|
||||
model.lm_head.to(device)
|
||||
|
||||
for layer in layers:
|
||||
if hasattr(layer, "self_attn"):
|
||||
layer.self_attn.to(device)
|
||||
|
||||
if hasattr(layer, "input_layernorm"):
|
||||
layer.input_layernorm.to(device)
|
||||
if hasattr(layer, "post_attention_layernorm"):
|
||||
layer.post_attention_layernorm.to(device)
|
||||
|
||||
moe_module = getattr(layer, moe_config.moe_layer_attr, None)
|
||||
if moe_module is None or not hasattr(moe_module, moe_config.experts_attr):
|
||||
if hasattr(layer, "mlp"):
|
||||
layer.mlp.to(device)
|
||||
continue
|
||||
|
||||
router = getattr(moe_module, moe_config.router_attr, None)
|
||||
if router is not None:
|
||||
router.to(device)
|
||||
|
||||
if hasattr(moe_module, "shared_experts") and moe_module.shared_experts is not None:
|
||||
moe_module.shared_experts.to(device)
|
||||
|
||||
logger.info(f"Moved non-expert parameters to {device}")
|
||||
|
||||
|
||||
def get_expert_device(model: nn.Module, moe_config: MOEArchConfig | None = None) -> str:
|
||||
"""Get the device type of MoE experts."""
|
||||
if moe_config is None:
|
||||
config = getattr(model, "config", None)
|
||||
if config is None:
|
||||
return "unknown"
|
||||
moe_config = get_moe_arch_config(config)
|
||||
|
||||
try:
|
||||
_, layers = _get_model_container_and_layers(model, purpose="expert device probing")
|
||||
except KTAMXConfigError:
|
||||
return "unknown"
|
||||
|
||||
for layer in layers:
|
||||
moe_module = getattr(layer, moe_config.moe_layer_attr, None)
|
||||
if moe_module is None:
|
||||
continue
|
||||
experts = getattr(moe_module, moe_config.experts_attr, None)
|
||||
if not experts:
|
||||
continue
|
||||
first_expert = experts[0]
|
||||
gate_name = moe_config.weight_names[0]
|
||||
gate_proj = getattr(first_expert, gate_name, None)
|
||||
if gate_proj is not None:
|
||||
return str(gate_proj.weight.device.type)
|
||||
|
||||
return "unknown"
|
||||
254
kt-kernel/python/sft/autograd.py
Normal file
254
kt-kernel/python/sft/autograd.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
# Autograd function for KT MoE SFT training
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from .dist_utils import (
|
||||
_all_gather_qlens,
|
||||
_qlen_offsets,
|
||||
_dist_gather_varlen_to_rank0,
|
||||
_dist_scatter_varlen_from_rank0,
|
||||
)
|
||||
|
||||
_KT_SFT_DEBUG = os.environ.get("KT_SFT_DEBUG", "0") == "1"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KTMoEFunction(torch.autograd.Function):
|
||||
"""Unified autograd function for KTMoE forward/backward."""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
wrapper: Any,
|
||||
lora_ref: torch.Tensor,
|
||||
hidden_size: int,
|
||||
num_experts_per_tok: int,
|
||||
layer_idx: int,
|
||||
training: bool,
|
||||
train_lora: bool,
|
||||
all_qlens: list[int] | tuple[int, ...] | None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if _KT_SFT_DEBUG:
|
||||
logging.debug(
|
||||
"KTMoEFunction.forward: layer=%d training=%s train_lora=%s",
|
||||
layer_idx, training, train_lora,
|
||||
)
|
||||
|
||||
original_device = hidden_states.device
|
||||
original_dtype = hidden_states.dtype
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
qlen = batch_size * seq_len
|
||||
|
||||
import torch.distributed as dist
|
||||
dist_on = dist.is_initialized() and dist.get_world_size() > 1
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
world_size = dist.get_world_size() if dist_on else 1
|
||||
|
||||
ctx.use_broadcast = wrapper is None
|
||||
|
||||
# ---- Sync CPU expert result and distribute ----
|
||||
if dist_on:
|
||||
if all_qlens is None:
|
||||
all_qlens_list = _all_gather_qlens(qlen, original_device, world_size)
|
||||
else:
|
||||
all_qlens_list = [int(q) for q in all_qlens]
|
||||
if len(all_qlens_list) != world_size:
|
||||
raise RuntimeError(
|
||||
f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}"
|
||||
)
|
||||
if int(all_qlens_list[rank]) != qlen:
|
||||
raise RuntimeError(
|
||||
f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}"
|
||||
)
|
||||
total_qlen = sum(all_qlens_list)
|
||||
|
||||
# Rank 0: sync CPU result and split by real lengths
|
||||
if rank == 0:
|
||||
cpu_output = wrapper.sync_forward(output_device=original_device)
|
||||
cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, hidden_size)
|
||||
offsets = _qlen_offsets(all_qlens_list)
|
||||
scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
|
||||
else:
|
||||
scatter_list = None
|
||||
|
||||
output_flat = _dist_scatter_varlen_from_rank0(
|
||||
rank0_chunks=scatter_list,
|
||||
all_qlens=all_qlens_list,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
feature_shape=(hidden_size,),
|
||||
device=original_device,
|
||||
dtype=original_dtype,
|
||||
)
|
||||
output = output_flat.view(batch_size, seq_len, hidden_size)
|
||||
del output_flat
|
||||
elif wrapper is not None:
|
||||
# Single-GPU: sync directly
|
||||
cpu_output = wrapper.sync_forward(output_device=original_device)
|
||||
output = cpu_output.view(batch_size, seq_len, hidden_size).to(dtype=original_dtype)
|
||||
else:
|
||||
# Broadcast-only rank (no wrapper)
|
||||
output = torch.empty(
|
||||
batch_size, seq_len, hidden_size, device=original_device, dtype=original_dtype
|
||||
)
|
||||
|
||||
ctx.wrapper = wrapper
|
||||
ctx.hidden_size = hidden_size
|
||||
ctx.qlen = qlen
|
||||
ctx.batch_size = batch_size
|
||||
ctx.seq_len = seq_len
|
||||
ctx.original_device = original_device
|
||||
ctx.original_dtype = original_dtype
|
||||
ctx.weights_shape = topk_weights.shape
|
||||
ctx.weights_dtype = topk_weights.dtype
|
||||
ctx.weights_device = topk_weights.device
|
||||
ctx.dist_on = dist_on
|
||||
ctx.world_size = world_size
|
||||
ctx.all_qlens = all_qlens_list if dist_on else None
|
||||
ctx.num_experts_per_tok = num_experts_per_tok
|
||||
ctx.layer_idx = layer_idx
|
||||
|
||||
# Save a sentinel tensor so non-reentrant checkpoint's saved_tensors
|
||||
# hooks can intercept it. When backward accesses ctx.saved_tensors,
|
||||
# the checkpoint unpack hook triggers a full recompute of the decoder
|
||||
# layer — which re-runs the MoE forward with save_for_backward=True,
|
||||
# populating the C++ cache BEFORE this backward proceeds.
|
||||
# Without this, MoE backward runs before the recompute (MoE comes
|
||||
# after attention in forward order → its backward runs first), and
|
||||
# the C++ cache is empty when first-forward cache-skip is active.
|
||||
ctx.save_for_backward(hidden_states.new_empty(()))
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output: torch.Tensor):
|
||||
# Wait for any in-flight async repack before recompute forward uses the pool
|
||||
if getattr(ctx.wrapper, 'share_backward_bb', False):
|
||||
ctx.wrapper.wait_backward_repack()
|
||||
|
||||
# Access saved_tensors FIRST — under non-reentrant checkpoint this
|
||||
# triggers the unpack hook which runs a full decoder-layer recompute,
|
||||
# populating the C++ cache before we call wrapper.backward().
|
||||
_ = ctx.saved_tensors
|
||||
|
||||
qlen = ctx.qlen
|
||||
hidden_size = ctx.hidden_size
|
||||
batch_size = ctx.batch_size
|
||||
seq_len = ctx.seq_len
|
||||
dist_on = ctx.dist_on
|
||||
world_size = ctx.world_size
|
||||
num_experts_per_tok = ctx.num_experts_per_tok
|
||||
|
||||
import torch.distributed as dist
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
|
||||
if _KT_SFT_DEBUG:
|
||||
logging.debug(
|
||||
"KTMoEFunction.backward: layer=%d dist_on=%s qlen=%d",
|
||||
getattr(ctx, "layer_idx", -1), dist_on, qlen,
|
||||
)
|
||||
|
||||
if dist_on:
|
||||
all_qlens = getattr(ctx, "all_qlens", None)
|
||||
if all_qlens is None or len(all_qlens) != world_size:
|
||||
all_qlens = _all_gather_qlens(qlen, ctx.original_device, world_size)
|
||||
else:
|
||||
all_qlens = [int(q) for q in all_qlens]
|
||||
if int(all_qlens[rank]) != qlen:
|
||||
raise RuntimeError(
|
||||
f"Backward qlen mismatch on rank {rank}: local={qlen}, all_qlens[{rank}]={all_qlens[rank]}"
|
||||
)
|
||||
|
||||
grad_out_flat = grad_output.view(qlen, hidden_size).contiguous()
|
||||
|
||||
gathered_go = _dist_gather_varlen_to_rank0(
|
||||
grad_out_flat,
|
||||
all_qlens=all_qlens,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
if rank == 0:
|
||||
all_go = torch.cat(gathered_go, dim=0)
|
||||
total_qlen = int(all_go.shape[0])
|
||||
|
||||
backward_out = ctx.wrapper.backward(
|
||||
all_go,
|
||||
output_device=ctx.original_device,
|
||||
)
|
||||
if isinstance(backward_out, tuple) and len(backward_out) == 2:
|
||||
all_grad_input, all_grad_weights = backward_out
|
||||
elif isinstance(backward_out, tuple) and len(backward_out) == 3:
|
||||
all_grad_input, _, all_grad_weights = backward_out
|
||||
else:
|
||||
raise ValueError("KTMoEWrapper.backward returned unexpected format.")
|
||||
|
||||
all_grad_input = all_grad_input.to(dtype=ctx.original_dtype).view(total_qlen, hidden_size)
|
||||
all_grad_weights = all_grad_weights.to(dtype=torch.bfloat16).view(total_qlen, num_experts_per_tok)
|
||||
|
||||
offsets = _qlen_offsets(all_qlens)
|
||||
scatter_gi = [all_grad_input[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
|
||||
scatter_gw = [all_grad_weights[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
|
||||
else:
|
||||
scatter_gi = None
|
||||
scatter_gw = None
|
||||
|
||||
grad_input_flat = _dist_scatter_varlen_from_rank0(
|
||||
rank0_chunks=scatter_gi,
|
||||
all_qlens=all_qlens,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
feature_shape=(hidden_size,),
|
||||
device=ctx.original_device,
|
||||
dtype=ctx.original_dtype,
|
||||
)
|
||||
grad_weights_flat = _dist_scatter_varlen_from_rank0(
|
||||
rank0_chunks=scatter_gw,
|
||||
all_qlens=all_qlens,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
feature_shape=(num_experts_per_tok,),
|
||||
device=ctx.weights_device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
grad_input = grad_input_flat.view(batch_size, seq_len, hidden_size)
|
||||
grad_weights = grad_weights_flat.view(ctx.weights_shape).to(dtype=ctx.weights_dtype)
|
||||
|
||||
elif not ctx.use_broadcast:
|
||||
# ---- Single-GPU path ----
|
||||
grad_output_flat = grad_output.view(qlen, hidden_size)
|
||||
backward_out = ctx.wrapper.backward(
|
||||
grad_output_flat,
|
||||
output_device=ctx.original_device,
|
||||
)
|
||||
ctx.wrapper._kt_has_cached_forward = False
|
||||
if isinstance(backward_out, tuple) and len(backward_out) == 2:
|
||||
grad_input, grad_weights = backward_out
|
||||
elif isinstance(backward_out, tuple) and len(backward_out) == 3:
|
||||
grad_input, _, grad_weights = backward_out
|
||||
else:
|
||||
raise ValueError("KTMoEWrapper.backward returned unexpected format.")
|
||||
grad_input = grad_input.view(batch_size, seq_len, hidden_size).to(dtype=ctx.original_dtype)
|
||||
grad_weights = grad_weights.to(dtype=torch.bfloat16)
|
||||
else:
|
||||
# No wrapper, no dist — shouldn't happen in normal flow
|
||||
grad_input = torch.zeros(batch_size, seq_len, hidden_size, device=ctx.original_device, dtype=ctx.original_dtype)
|
||||
grad_weights = torch.zeros(ctx.weights_shape, device=ctx.weights_device, dtype=ctx.weights_dtype)
|
||||
|
||||
# Trigger async repack for next MoE layer in backward order
|
||||
next_bwd = getattr(ctx.wrapper, '_next_backward_wrapper', None)
|
||||
if next_bwd is not None and getattr(next_bwd, 'share_backward_bb', False):
|
||||
next_bwd.submit_backward_repack()
|
||||
|
||||
return grad_input, None, grad_weights, None, None, None, None, None, None, None, None
|
||||
402
kt-kernel/python/sft/base.py
Normal file
402
kt-kernel/python/sft/base.py
Normal file
|
|
@ -0,0 +1,402 @@
|
|||
# Base classes for SFT MoE operations
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
SFT (Supervised Fine-Tuning) MoE base classes and buffer management.
|
||||
|
||||
Provides:
|
||||
- KExpertsSFTBuffer: Grow-only shared buffer for forward/backward passes
|
||||
- BaseSFTMoEWrapper: Abstract base with concrete buffer management (template method pattern)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..experts_base import _MoEBase
|
||||
|
||||
|
||||
class KExpertsSFTBuffer:
|
||||
"""
|
||||
CPU buffer management for SFT expert computation.
|
||||
|
||||
Single grow-only buffer (never shrinks). Callers must use [:qlen] slicing
|
||||
since the buffer may be larger than the current batch.
|
||||
"""
|
||||
|
||||
_shared_buffer: Optional["KExpertsSFTBuffer"] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qlen: int,
|
||||
hidden_size: int,
|
||||
moe_intermediate_size: int,
|
||||
num_experts: int,
|
||||
num_experts_per_tok: int,
|
||||
lora_rank: int,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
self.qlen = qlen
|
||||
self.hidden_size = hidden_size
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_experts = num_experts
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.lora_rank = lora_rank
|
||||
self.dtype = dtype
|
||||
|
||||
pin_memory = False
|
||||
|
||||
# Forward buffers
|
||||
self.input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
|
||||
self.expert_ids_cpu = torch.empty(
|
||||
(qlen, num_experts_per_tok), dtype=torch.int64, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.weights_cpu = torch.empty(
|
||||
(qlen, num_experts_per_tok), dtype=torch.float32, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
|
||||
|
||||
# Backward buffers
|
||||
self.grad_output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
|
||||
self.grad_input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
|
||||
self.grad_weights = torch.empty((qlen, num_experts_per_tok), dtype=torch.float32, device="cpu")
|
||||
|
||||
# Batch size tensor for C++ interface
|
||||
self.bsz_tensor = torch.tensor([qlen], dtype=torch.int32, device="cpu")
|
||||
|
||||
@classmethod
|
||||
def get_buffer(
|
||||
cls,
|
||||
qlen: int,
|
||||
hidden_size: int,
|
||||
moe_intermediate_size: int,
|
||||
num_experts: int,
|
||||
num_experts_per_tok: int,
|
||||
lora_rank: int,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
) -> "KExpertsSFTBuffer":
|
||||
"""Get or grow the single shared buffer. Only reallocates when qlen exceeds capacity."""
|
||||
buf = cls._shared_buffer
|
||||
if buf is not None and qlen <= buf.qlen:
|
||||
return buf
|
||||
cls._shared_buffer = cls(
|
||||
qlen=qlen,
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
num_experts=num_experts,
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
lora_rank=lora_rank,
|
||||
dtype=dtype,
|
||||
)
|
||||
return cls._shared_buffer
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls) -> None:
|
||||
"""Clear the shared buffer."""
|
||||
cls._shared_buffer = None
|
||||
|
||||
|
||||
class BaseSFTMoEWrapper(_MoEBase, ABC):
|
||||
"""
|
||||
Base class for SFT MoE CPU operations with concrete buffer management.
|
||||
|
||||
Subclasses implement:
|
||||
- _make_forward_task(buffer, save_for_backward) -> C++ task object
|
||||
- _make_backward_task(buffer) -> C++ task object
|
||||
- load_weights(physical_to_logical_map_cpu)
|
||||
- init_lora_weights(...)
|
||||
- update_lora_weights()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_idx: int,
|
||||
num_experts: int,
|
||||
num_experts_per_tok: int,
|
||||
hidden_size: int,
|
||||
moe_intermediate_size: int,
|
||||
num_gpu_experts: int,
|
||||
cpuinfer_threads: int,
|
||||
threadpool_count: int,
|
||||
weight_path: str,
|
||||
chunked_prefill_size: int,
|
||||
lora_rank: int = 16,
|
||||
lora_alpha: float = 32.0,
|
||||
max_cache_depth: int = 1,
|
||||
):
|
||||
self.cpu_infer = self._get_cpu_infer(cpuinfer_threads, threadpool_count)
|
||||
|
||||
self._validate_base_config(
|
||||
num_experts=num_experts,
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
)
|
||||
self._validate_sft_config(lora_rank, lora_alpha, max_cache_depth)
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.num_experts = num_experts
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.hidden_size = hidden_size
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_gpu_experts = num_gpu_experts
|
||||
self.weight_path = weight_path
|
||||
self.chunked_prefill_size = chunked_prefill_size
|
||||
self.threadpool_count = threadpool_count
|
||||
|
||||
self.lora_rank = lora_rank
|
||||
self.lora_alpha = lora_alpha
|
||||
self.lora_scaling = lora_alpha / lora_rank
|
||||
self.max_cache_depth = max_cache_depth
|
||||
|
||||
self.gate_lora_a: Optional[torch.Tensor] = None
|
||||
self.gate_lora_b: Optional[torch.Tensor] = None
|
||||
self.up_lora_a: Optional[torch.Tensor] = None
|
||||
self.up_lora_b: Optional[torch.Tensor] = None
|
||||
self.down_lora_a: Optional[torch.Tensor] = None
|
||||
self.down_lora_b: Optional[torch.Tensor] = None
|
||||
|
||||
self._weights_loaded: bool = False
|
||||
self._lora_initialized: bool = False
|
||||
self._cache_depth: int = 0
|
||||
self._is_skip_lora: bool = False
|
||||
|
||||
self.moe = None
|
||||
|
||||
@staticmethod
|
||||
def _validate_sft_config(lora_rank: int, lora_alpha: float, max_cache_depth: int) -> None:
|
||||
if lora_rank <= 0:
|
||||
raise ValueError(f"lora_rank must be positive, got {lora_rank}")
|
||||
if lora_alpha <= 0:
|
||||
raise ValueError(f"lora_alpha must be positive, got {lora_alpha}")
|
||||
if max_cache_depth <= 0:
|
||||
raise ValueError(f"max_cache_depth must be positive, got {max_cache_depth}")
|
||||
|
||||
# ========== Abstract methods for subclasses ==========
|
||||
|
||||
@abstractmethod
|
||||
def _make_forward_task(self, buffer: KExpertsSFTBuffer, save_for_backward: bool):
|
||||
"""Construct the C++ forward task object. Backend-specific."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _make_backward_task(self, buffer: KExpertsSFTBuffer):
|
||||
"""Construct the C++ backward task object. Backend-specific."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def init_lora_weights(
|
||||
self,
|
||||
gate_lora_a: torch.Tensor, gate_lora_b: torch.Tensor,
|
||||
up_lora_a: torch.Tensor, up_lora_b: torch.Tensor,
|
||||
down_lora_a: torch.Tensor, down_lora_b: torch.Tensor,
|
||||
grad_gate_lora_a: torch.Tensor, grad_gate_lora_b: torch.Tensor,
|
||||
grad_up_lora_a: torch.Tensor, grad_up_lora_b: torch.Tensor,
|
||||
grad_down_lora_a: torch.Tensor, grad_down_lora_b: torch.Tensor,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def update_lora_weights(self) -> None:
|
||||
...
|
||||
|
||||
# ========== Buffer helpers ==========
|
||||
|
||||
def _get_buffer(self, qlen: int) -> KExpertsSFTBuffer:
|
||||
return KExpertsSFTBuffer.get_buffer(
|
||||
qlen=qlen,
|
||||
hidden_size=self.hidden_size,
|
||||
moe_intermediate_size=self.moe_intermediate_size,
|
||||
num_experts=self.num_experts,
|
||||
num_experts_per_tok=self.num_experts_per_tok,
|
||||
lora_rank=self.lora_rank,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
def _validate_forward_inputs(self, hidden_states: torch.Tensor, expert_ids: torch.Tensor, weights: torch.Tensor):
|
||||
if not self._weights_loaded:
|
||||
raise RuntimeError("Weights not loaded. Call load_weights() or load_weights_from_tensors() first.")
|
||||
if not self._lora_initialized and not self._is_skip_lora:
|
||||
raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.")
|
||||
qlen = hidden_states.shape[0]
|
||||
if qlen > self.chunked_prefill_size:
|
||||
raise ValueError(
|
||||
f"qlen ({qlen}) exceeds chunked_prefill_size ({self.chunked_prefill_size}). "
|
||||
"Increase chunked_prefill_size or reduce qlen to avoid buffer overrun."
|
||||
)
|
||||
if expert_ids.shape[0] != qlen or expert_ids.shape[1] != self.num_experts_per_tok:
|
||||
raise ValueError(
|
||||
f"expert_ids shape {tuple(expert_ids.shape)} must be ({qlen}, {self.num_experts_per_tok})."
|
||||
)
|
||||
if weights.shape[0] != qlen or weights.shape[1] != self.num_experts_per_tok:
|
||||
raise ValueError(
|
||||
f"weights shape {tuple(weights.shape)} must be ({qlen}, {self.num_experts_per_tok})."
|
||||
)
|
||||
|
||||
def _copy_inputs_to_buffer(self, buffer: KExpertsSFTBuffer, hidden_states: torch.Tensor,
|
||||
expert_ids: torch.Tensor, weights: torch.Tensor, qlen: int) -> torch.device:
|
||||
"""Copy inputs to CPU buffer, return input device."""
|
||||
input_device = hidden_states.device
|
||||
buffer.input_cpu[:qlen].copy_(hidden_states.to(torch.bfloat16), non_blocking=True)
|
||||
buffer.expert_ids_cpu[:qlen].copy_(expert_ids.to(torch.int64), non_blocking=True)
|
||||
buffer.weights_cpu[:qlen].copy_(weights.to(torch.float32), non_blocking=True)
|
||||
buffer.bsz_tensor[0] = qlen
|
||||
if input_device.type == "cuda":
|
||||
torch.cuda.synchronize(input_device)
|
||||
return input_device
|
||||
|
||||
def _copy_grad_output_to_cpu(self, buffer: KExpertsSFTBuffer, grad_output: torch.Tensor, qlen: int):
|
||||
"""Copy grad_output to CPU buffer."""
|
||||
input_device = grad_output.device
|
||||
if input_device.type == "cuda":
|
||||
torch.cuda.synchronize(input_device)
|
||||
buffer.grad_output_cpu[:qlen].copy_(grad_output.to(torch.bfloat16))
|
||||
|
||||
def _return_output(self, buffer: KExpertsSFTBuffer, qlen: int, output_device: Optional[torch.device]):
|
||||
if output_device is not None:
|
||||
return buffer.output_cpu[:qlen].to(device=output_device, non_blocking=True)
|
||||
else:
|
||||
return buffer.output_cpu[:qlen].clone()
|
||||
|
||||
def _return_grads(self, buffer: KExpertsSFTBuffer, qlen: int, output_device: Optional[torch.device]):
|
||||
if output_device is not None:
|
||||
grad_input = buffer.grad_input_cpu[:qlen].to(device=output_device, non_blocking=True)
|
||||
grad_weights = buffer.grad_weights[:qlen].to(device=output_device, non_blocking=True)
|
||||
else:
|
||||
grad_input = buffer.grad_input_cpu[:qlen].clone()
|
||||
grad_weights = buffer.grad_weights[:qlen].clone()
|
||||
return grad_input, grad_weights
|
||||
|
||||
# ========== Concrete forward/backward ==========
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
save_for_backward: bool = True,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Synchronous forward pass with optional gradient caching."""
|
||||
self._validate_forward_inputs(hidden_states, expert_ids, weights)
|
||||
qlen = hidden_states.shape[0]
|
||||
buffer = self._get_buffer(qlen)
|
||||
self._copy_inputs_to_buffer(buffer, hidden_states, expert_ids, weights, qlen)
|
||||
|
||||
self.cpu_infer.submit(self._make_forward_task(buffer, save_for_backward))
|
||||
self.cpu_infer.sync()
|
||||
|
||||
if save_for_backward and self._cache_depth == 0:
|
||||
self._cache_depth += 1
|
||||
|
||||
return self._return_output(buffer, qlen, output_device)
|
||||
|
||||
def backward(
|
||||
self,
|
||||
grad_output: torch.Tensor,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Backward pass computing grad_input and grad_weights."""
|
||||
if self._cache_depth <= 0:
|
||||
raise RuntimeError("No forward cache available. Call forward(save_for_backward=True) first.")
|
||||
|
||||
qlen = grad_output.shape[0]
|
||||
buffer = self._get_buffer(qlen)
|
||||
self._copy_grad_output_to_cpu(buffer, grad_output, qlen)
|
||||
|
||||
self.cpu_infer.submit(self._make_backward_task(buffer))
|
||||
self.cpu_infer.sync()
|
||||
|
||||
self._cache_depth -= 1
|
||||
return self._return_grads(buffer, qlen, output_device)
|
||||
|
||||
# ========== Async forward ==========
|
||||
|
||||
def submit_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
save_for_backward: bool = True,
|
||||
) -> None:
|
||||
"""Submit forward pass asynchronously (non-blocking). Call sync_forward() to get results."""
|
||||
self._validate_forward_inputs(hidden_states, expert_ids, weights)
|
||||
qlen = hidden_states.shape[0]
|
||||
buffer = self._get_buffer(qlen)
|
||||
self._copy_inputs_to_buffer(buffer, hidden_states, expert_ids, weights, qlen)
|
||||
|
||||
self._pending_buffer = buffer
|
||||
self._pending_save_for_backward = save_for_backward
|
||||
self._pending_qlen = qlen
|
||||
|
||||
self.cpu_infer.submit(self._make_forward_task(buffer, save_for_backward))
|
||||
|
||||
def sync_forward(self, output_device: Optional[torch.device] = None) -> torch.Tensor:
|
||||
"""Synchronize and retrieve forward results. Must be called after submit_forward()."""
|
||||
if not hasattr(self, "_pending_buffer") or self._pending_buffer is None:
|
||||
raise RuntimeError("No pending forward. Call submit_forward() first.")
|
||||
|
||||
self.cpu_infer.sync()
|
||||
|
||||
buffer = self._pending_buffer
|
||||
save_for_backward = self._pending_save_for_backward
|
||||
qlen = self._pending_qlen
|
||||
|
||||
if save_for_backward and self._cache_depth == 0:
|
||||
self._cache_depth += 1
|
||||
|
||||
self._pending_buffer = None
|
||||
self._pending_save_for_backward = None
|
||||
self._pending_qlen = None
|
||||
|
||||
return self._return_output(buffer, qlen, output_device)
|
||||
|
||||
# ========== Async backward ==========
|
||||
|
||||
def submit_backward_async(
|
||||
self,
|
||||
grad_output: torch.Tensor,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> None:
|
||||
"""Submit backward task without waiting. Call sync_backward() for results."""
|
||||
if self._cache_depth <= 0:
|
||||
raise RuntimeError("No forward cache available. Call forward(save_for_backward=True) first.")
|
||||
|
||||
qlen = grad_output.shape[0]
|
||||
buffer = self._get_buffer(qlen)
|
||||
self._copy_grad_output_to_cpu(buffer, grad_output, qlen)
|
||||
|
||||
self.cpu_infer.submit(self._make_backward_task(buffer))
|
||||
self._async_bwd_qlen = qlen
|
||||
self._async_bwd_output_device = output_device
|
||||
|
||||
def sync_backward(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Wait for async backward and return results."""
|
||||
self.cpu_infer.sync()
|
||||
|
||||
qlen = self._async_bwd_qlen
|
||||
output_device = self._async_bwd_output_device
|
||||
buffer = self._get_buffer(qlen)
|
||||
|
||||
self._cache_depth -= 1
|
||||
return self._return_grads(buffer, qlen, output_device)
|
||||
|
||||
# ========== Backward repack (optional, subclasses may override) ==========
|
||||
|
||||
def submit_backward_repack(self):
|
||||
if not self._weights_loaded or self.moe is None:
|
||||
return
|
||||
if hasattr(self.moe, 'submit_backward_repack'):
|
||||
self.moe.submit_backward_repack()
|
||||
|
||||
def wait_backward_repack(self):
|
||||
if not self._weights_loaded or self.moe is None:
|
||||
return
|
||||
if hasattr(self.moe, 'wait_backward_repack'):
|
||||
self.moe.wait_backward_repack()
|
||||
139
kt-kernel/python/sft/config.py
Normal file
139
kt-kernel/python/sft/config.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
# KT-Kernel SFT configuration
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
KTConfig: kt-kernel's own configuration dataclass.
|
||||
|
||||
This is the kt-kernel equivalent of DeepSpeed's JSON config —
|
||||
it holds all kt-kernel-specific settings and is passed through
|
||||
KTransformersPlugin.kt_config (similar to DeepSpeedPlugin.hf_ds_config).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
def _env_int(key: str, default: int | None) -> int | None:
|
||||
value = os.environ.get(key, None)
|
||||
if value is None or value == "":
|
||||
return default
|
||||
return int(value)
|
||||
|
||||
|
||||
def _env_float(key: str, default: float | None) -> float | None:
|
||||
value = os.environ.get(key, None)
|
||||
if value is None or value == "":
|
||||
return default
|
||||
return float(value)
|
||||
|
||||
|
||||
def _env_bool(key: str, default: bool) -> bool:
|
||||
value = os.environ.get(key, None)
|
||||
if value is None or value == "":
|
||||
return default
|
||||
return value.lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
@dataclass
|
||||
class KTConfig:
|
||||
"""
|
||||
KT-Kernel configuration for SFT training.
|
||||
|
||||
All field names use the ``kt_`` prefix so they match the dict keys used in
|
||||
HfTrainerKTConfig / YAML configs. This means ``KTConfig(**dict)`` works
|
||||
directly — no name-mapping or prefix-stripping needed.
|
||||
|
||||
Can be created from:
|
||||
- Direct construction: KTConfig(kt_backend="AMXBF16", kt_weight_path="/path/...")
|
||||
- Dict: KTConfig(**config_dict)
|
||||
- Environment variables: KTConfig() reads ACCELERATE_KT_* env vars as defaults
|
||||
"""
|
||||
|
||||
# Backend selection
|
||||
kt_backend: str | None = None
|
||||
kt_num_threads: int | None = None
|
||||
kt_tp_enabled: bool | None = None
|
||||
kt_threadpool_count: int | None = None
|
||||
|
||||
# Weight loading
|
||||
kt_weight_path: str | None = None
|
||||
kt_expert_checkpoint_path: str | None = None
|
||||
kt_num_gpu_experts: int | None = None
|
||||
kt_skip_expert_loading: bool | None = None
|
||||
kt_share_backward_bb: bool | None = None # default True — always saves memory
|
||||
kt_share_cache_pool: bool | None = None # auto-set by trainer_config_process, not user-facing
|
||||
|
||||
# Cache
|
||||
kt_max_cache_depth: int | None = None
|
||||
kt_model_max_length: int | None = None
|
||||
|
||||
# LoRA
|
||||
kt_lora_rank: int | None = None
|
||||
kt_lora_alpha: float | None = None
|
||||
|
||||
# LoRA Experts (GPU-side extra experts)
|
||||
kt_use_lora_experts: bool | None = None
|
||||
kt_lora_expert_num: int | None = None
|
||||
kt_lora_expert_intermediate_size: int | None = None
|
||||
|
||||
# Runtime state (set during wrapping, not by user)
|
||||
kt_checkpoint_files: list[str] | None = None
|
||||
kt_sharded_metadata: dict | None = None
|
||||
|
||||
# Custom wrapping
|
||||
kt_wrap_fn: Callable[..., Any] | None = None
|
||||
kt_wrap_kwargs: dict[str, Any] | None = None
|
||||
|
||||
@classmethod
|
||||
def from_object(cls, obj: Any) -> "KTConfig":
|
||||
"""Create KTConfig from an attribute-based object (HfTrainerKTConfig, etc.)."""
|
||||
_field_names = {f.name for f in dataclasses.fields(cls)}
|
||||
kwargs: dict[str, Any] = {}
|
||||
for name in _field_names:
|
||||
val = getattr(obj, name, None)
|
||||
if val is not None:
|
||||
kwargs[name] = val
|
||||
return cls(**kwargs)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.kt_backend is None:
|
||||
self.kt_backend = os.environ.get("ACCELERATE_KT_BACKEND", "AMXBF16")
|
||||
if self.kt_num_threads is None:
|
||||
self.kt_num_threads = _env_int("ACCELERATE_KT_NUM_THREADS", 1)
|
||||
if self.kt_tp_enabled is None:
|
||||
self.kt_tp_enabled = _env_bool("ACCELERATE_KT_TP_ENABLED", False)
|
||||
if self.kt_threadpool_count is None:
|
||||
self.kt_threadpool_count = _env_int("ACCELERATE_KT_THREADPOOL_COUNT", 1)
|
||||
if self.kt_weight_path is None:
|
||||
self.kt_weight_path = os.environ.get("ACCELERATE_KT_WEIGHT_PATH", None)
|
||||
if self.kt_expert_checkpoint_path is None:
|
||||
self.kt_expert_checkpoint_path = os.environ.get("ACCELERATE_KT_EXPERT_CHECKPOINT_PATH", None)
|
||||
if self.kt_num_gpu_experts is None:
|
||||
self.kt_num_gpu_experts = _env_int("ACCELERATE_KT_NUM_GPU_EXPERTS", 0)
|
||||
if self.kt_max_cache_depth is None:
|
||||
self.kt_max_cache_depth = _env_int("ACCELERATE_KT_MAX_CACHE_DEPTH", 2)
|
||||
if self.kt_share_backward_bb is None:
|
||||
self.kt_share_backward_bb = _env_bool("ACCELERATE_KT_SHARE_BACKWARD_BB", True)
|
||||
if self.kt_share_cache_pool is None:
|
||||
self.kt_share_cache_pool = False
|
||||
if self.kt_use_lora_experts is None:
|
||||
self.kt_use_lora_experts = _env_bool("ACCELERATE_KT_USE_LORA_EXPERTS", False)
|
||||
if self.kt_lora_expert_num is None:
|
||||
self.kt_lora_expert_num = _env_int("ACCELERATE_KT_LORA_EXPERT_NUM", None)
|
||||
if self.kt_lora_expert_intermediate_size is None:
|
||||
self.kt_lora_expert_intermediate_size = _env_int("ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE", None)
|
||||
if self.kt_lora_rank is None:
|
||||
self.kt_lora_rank = _env_int("ACCELERATE_KT_LORA_RANK", None)
|
||||
if self.kt_lora_alpha is None:
|
||||
self.kt_lora_alpha = _env_float("ACCELERATE_KT_LORA_ALPHA", None)
|
||||
if self.kt_lora_alpha is None and self.kt_lora_rank is not None:
|
||||
self.kt_lora_alpha = float(self.kt_lora_rank * 2)
|
||||
if self.kt_model_max_length is None:
|
||||
self.kt_model_max_length = _env_int("ACCELERATE_KT_MODEL_MAX_LENGTH", None)
|
||||
if self.kt_skip_expert_loading is None:
|
||||
if "ACCELERATE_KT_SKIP_EXPERT_LOADING" in os.environ:
|
||||
self.kt_skip_expert_loading = _env_bool("ACCELERATE_KT_SKIP_EXPERT_LOADING", True)
|
||||
171
kt-kernel/python/sft/dist_utils.py
Normal file
171
kt-kernel/python/sft/dist_utils.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
# Distributed and checkpoint utilities for SFT
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
Shared distributed communication and gradient-checkpoint detection helpers.
|
||||
|
||||
This is a leaf module — no imports from other sft/ submodules.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _all_gather_qlens(local_qlen: int, device: torch.device, world_size: int) -> list[int]:
|
||||
import torch.distributed as dist
|
||||
|
||||
local_qlen_t = torch.tensor([int(local_qlen)], device=device, dtype=torch.int64)
|
||||
gathered = [torch.empty(1, device=device, dtype=torch.int64) for _ in range(world_size)]
|
||||
dist.all_gather(gathered, local_qlen_t)
|
||||
return [int(t.item()) for t in gathered]
|
||||
|
||||
|
||||
def _qlen_offsets(all_qlens: list[int]) -> list[int]:
|
||||
offsets = [0]
|
||||
for q in all_qlens:
|
||||
offsets.append(offsets[-1] + int(q))
|
||||
return offsets
|
||||
|
||||
|
||||
def _dist_gather_varlen_to_rank0(
|
||||
local_tensor: torch.Tensor,
|
||||
*,
|
||||
all_qlens: list[int],
|
||||
rank: int,
|
||||
world_size: int,
|
||||
) -> list[torch.Tensor] | None:
|
||||
import torch.distributed as dist
|
||||
|
||||
local_tensor = local_tensor.contiguous()
|
||||
local_expected = int(all_qlens[rank])
|
||||
if local_tensor.shape[0] != local_expected:
|
||||
raise RuntimeError(
|
||||
f"Local leading dim mismatch on rank {rank}: got {local_tensor.shape[0]}, expected {local_expected}"
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
gathered: list[torch.Tensor | None] = [None] * world_size
|
||||
gathered[0] = local_tensor
|
||||
ops: list[dist.P2POp] = []
|
||||
for src in range(1, world_size):
|
||||
qlen_src = int(all_qlens[src])
|
||||
recv_shape = (qlen_src, *local_tensor.shape[1:])
|
||||
recv = torch.empty(recv_shape, device=local_tensor.device, dtype=local_tensor.dtype)
|
||||
gathered[src] = recv
|
||||
if qlen_src > 0:
|
||||
ops.append(dist.P2POp(dist.irecv, recv, src))
|
||||
if ops:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
out: list[torch.Tensor] = []
|
||||
for idx, t in enumerate(gathered):
|
||||
if t is None:
|
||||
raise RuntimeError(f"Missing gathered tensor for rank {idx} on rank0.")
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
if local_expected > 0:
|
||||
reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, local_tensor, 0)])
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
return None
|
||||
|
||||
|
||||
def _dist_scatter_varlen_from_rank0(
|
||||
*,
|
||||
rank0_chunks: list[torch.Tensor] | None,
|
||||
all_qlens: list[int],
|
||||
rank: int,
|
||||
world_size: int,
|
||||
feature_shape: tuple[int, ...],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
import torch.distributed as dist
|
||||
|
||||
local_qlen = int(all_qlens[rank])
|
||||
local_out = torch.empty((local_qlen, *feature_shape), device=device, dtype=dtype)
|
||||
|
||||
if rank == 0:
|
||||
if rank0_chunks is None or len(rank0_chunks) != world_size:
|
||||
raise RuntimeError("rank0_chunks must contain one chunk per rank on rank0.")
|
||||
if int(rank0_chunks[0].shape[0]) != local_qlen:
|
||||
raise RuntimeError(
|
||||
f"Rank0 local chunk mismatch: got {rank0_chunks[0].shape[0]}, expected {local_qlen}"
|
||||
)
|
||||
if local_qlen > 0:
|
||||
local_out.copy_(rank0_chunks[0])
|
||||
ops: list[dist.P2POp] = []
|
||||
for dst in range(1, world_size):
|
||||
qlen_dst = int(all_qlens[dst])
|
||||
if qlen_dst <= 0:
|
||||
continue
|
||||
chunk = rank0_chunks[dst].contiguous()
|
||||
if int(chunk.shape[0]) != qlen_dst:
|
||||
raise RuntimeError(
|
||||
f"Rank{dst} chunk mismatch on rank0: got {chunk.shape[0]}, expected {qlen_dst}"
|
||||
)
|
||||
ops.append(dist.P2POp(dist.isend, chunk, dst))
|
||||
if ops:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
return local_out
|
||||
|
||||
if local_qlen > 0:
|
||||
reqs = dist.batch_isend_irecv([dist.P2POp(dist.irecv, local_out, 0)])
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
return local_out
|
||||
|
||||
|
||||
|
||||
def _checkpoint_hook_mode() -> str:
|
||||
"""Infer checkpoint phase from current saved_tensors_hooks top.
|
||||
|
||||
Returns one of:
|
||||
- "first_forward": non-reentrant checkpoint's _checkpoint_hook
|
||||
- "recompute": non-reentrant checkpoint's _recomputation_hook
|
||||
- "none": no default saved_tensors_hooks on top
|
||||
- "other": unknown hook stack entry
|
||||
- "error": failed to query hook stack
|
||||
"""
|
||||
try:
|
||||
top = torch._C._autograd._top_saved_tensors_default_hooks(False)
|
||||
except Exception:
|
||||
return "error"
|
||||
if top is None:
|
||||
return "none"
|
||||
try:
|
||||
pack_fn, _ = top
|
||||
mod = getattr(pack_fn, "__module__", "")
|
||||
qual = getattr(pack_fn, "__qualname__", getattr(pack_fn, "__name__", ""))
|
||||
tag = f"{mod}.{qual}"
|
||||
except Exception:
|
||||
return "other"
|
||||
if "_recomputation_hook.__init__.<locals>.pack_hook" in tag:
|
||||
return "recompute"
|
||||
if "_checkpoint_hook.__init__.<locals>.pack_hook" in tag:
|
||||
return "first_forward"
|
||||
return "other"
|
||||
|
||||
|
||||
def _maybe_zero3_gathered_parameters(params: list[torch.nn.Parameter]):
|
||||
if not params:
|
||||
return nullcontext()
|
||||
try:
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
except Exception:
|
||||
return nullcontext()
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
return nullcontext()
|
||||
try:
|
||||
import deepspeed # type: ignore
|
||||
except Exception:
|
||||
return nullcontext()
|
||||
return deepspeed.zero.GatheredParameters(params, modifier_rank=0)
|
||||
399
kt-kernel/python/sft/layer.py
Normal file
399
kt-kernel/python/sft/layer.py
Normal file
|
|
@ -0,0 +1,399 @@
|
|||
# KTMoELayerWrapper — nn.Module replacing HF MoE layers for SFT
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
KTMoELayerWrapper: drop-in nn.Module replacement for HuggingFace MoE layers.
|
||||
|
||||
Delegates expert computation to the C++ KTMoEWrapper backend, with support
|
||||
for gradient checkpointing, PEFT LoRA on experts, LoRA Experts (separate
|
||||
small MLPs on GPU), shared experts, and multi-GPU rank-0-only execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .arch import MOEArchConfig
|
||||
from .autograd import KTMoEFunction
|
||||
from .dist_utils import (
|
||||
_all_gather_qlens,
|
||||
_checkpoint_hook_mode,
|
||||
_dist_gather_varlen_to_rank0,
|
||||
_dist_scatter_varlen_from_rank0,
|
||||
_qlen_offsets,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_KT_SFT_DEBUG = os.environ.get("KT_SFT_DEBUG", "0") == "1"
|
||||
|
||||
|
||||
class KTMoELayerWrapper(nn.Module):
|
||||
"""Wrapper for MoE layer using KTMoEWrapper."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
original_moe: nn.Module,
|
||||
wrapper: Any,
|
||||
lora_params: dict[str, nn.Parameter] | None, # Kept for backward compatibility, but ignored
|
||||
moe_config: MOEArchConfig,
|
||||
hidden_size: int,
|
||||
layer_idx: int,
|
||||
lora_experts: "LoRAExperts | None" = None,
|
||||
):
|
||||
super().__init__()
|
||||
self._is_kt_moe_wrapper = True
|
||||
|
||||
self.wrapper = wrapper
|
||||
self.moe_config = moe_config
|
||||
self.hidden_size = hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
self.router_type = moe_config.router_type
|
||||
|
||||
# IMPORTANT: Register submodules in the SAME ORDER as original MoE module
|
||||
# so that PEFT's named_modules() traversal order matches baseline.
|
||||
# This ensures kaiming_uniform_ calls happen in the same sequence.
|
||||
# Qwen3MoeSparseMoeBlock order: gate FIRST, then experts.
|
||||
|
||||
# 1. gate/router FIRST - keep original attribute name for PEFT compatibility
|
||||
router_attr = moe_config.router_attr # "gate" for Qwen3/DeepSeek
|
||||
setattr(self, router_attr, getattr(original_moe, router_attr, None))
|
||||
self._router_attr = router_attr
|
||||
|
||||
# 2. experts SECOND (this is what PEFT targets for LoRA)
|
||||
experts_attr = moe_config.experts_attr # typically "experts"
|
||||
setattr(self, experts_attr, getattr(original_moe, experts_attr, None))
|
||||
self._experts_attr = experts_attr
|
||||
|
||||
# 3. shared_experts (if any)
|
||||
if moe_config.has_shared_experts and hasattr(original_moe, "shared_experts"):
|
||||
self.shared_experts = original_moe.shared_experts
|
||||
else:
|
||||
self.shared_experts = None
|
||||
|
||||
# 4. lora_experts (separate LoRA expert MLPs, different from PEFT LoRA on experts)
|
||||
self.lora_experts = lora_experts
|
||||
|
||||
# PEFT LoRA tracking (set by kt_adapt_peft_lora)
|
||||
# _peft_lora_modules: {expert_idx: {proj_name: (lora_A, lora_B)}}
|
||||
self._peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]] | None = None
|
||||
self._lora_pointers_dirty = False
|
||||
|
||||
def _apply(self, fn, recurse=True):
|
||||
# Protect experts from device transfer (PEFT LoRA should stay on CPU for KT)
|
||||
saved_experts = None
|
||||
experts_attr = getattr(self, '_experts_attr', None)
|
||||
|
||||
if experts_attr is not None and getattr(self, experts_attr, None) is not None:
|
||||
saved_experts = getattr(self, experts_attr)
|
||||
self._modules.pop(experts_attr, None)
|
||||
|
||||
result = super()._apply(fn, recurse)
|
||||
|
||||
if saved_experts is not None:
|
||||
self._modules[experts_attr] = saved_experts
|
||||
|
||||
return result
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
import torch.distributed as dist
|
||||
dist_on = dist.is_initialized() and dist.get_world_size() > 1
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
|
||||
# Check if we need to use distributed broadcast (only rank 0 has KT kernel)
|
||||
use_broadcast = dist_on and self.wrapper is None
|
||||
|
||||
topk_ids, topk_weights = self._compute_routing(hidden_states)
|
||||
|
||||
train_lora = self._peft_lora_modules is not None and len(self._peft_lora_modules) > 0
|
||||
|
||||
save_for_backward = (
|
||||
self.training
|
||||
and torch.is_grad_enabled()
|
||||
and (hidden_states.requires_grad or topk_weights.requires_grad or train_lora)
|
||||
)
|
||||
use_autograd_path = save_for_backward
|
||||
save_for_backward_submit = use_autograd_path
|
||||
if _checkpoint_hook_mode() == "first_forward":
|
||||
save_for_backward_submit = False
|
||||
|
||||
if train_lora and self._lora_pointers_dirty:
|
||||
self.update_lora_pointers()
|
||||
self._lora_pointers_dirty = False
|
||||
|
||||
gpu_output, all_qlens = self._submit_and_compute_gpu(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
save_for_backward_submit,
|
||||
)
|
||||
|
||||
# Use KTMoEFunction whenever backward is needed so KT backward and LoRA
|
||||
# gradient paths remain connected.
|
||||
if use_autograd_path:
|
||||
lora_ref = hidden_states.new_empty(())
|
||||
if train_lora and self._peft_lora_modules:
|
||||
for expert_loras in self._peft_lora_modules.values():
|
||||
for lora_A, lora_B in expert_loras.values():
|
||||
if hasattr(lora_A, 'weight') and lora_A.weight.requires_grad:
|
||||
lora_ref = lora_A.weight
|
||||
break
|
||||
if lora_ref.numel() > 0:
|
||||
break
|
||||
|
||||
moe_output = KTMoEFunction.apply(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
self.wrapper,
|
||||
lora_ref,
|
||||
self.hidden_size,
|
||||
self.moe_config.num_experts_per_tok,
|
||||
self.layer_idx,
|
||||
save_for_backward,
|
||||
train_lora,
|
||||
all_qlens,
|
||||
)
|
||||
else:
|
||||
moe_output = self._sync_forward_output_no_autograd(
|
||||
hidden_states=hidden_states,
|
||||
all_qlens=all_qlens,
|
||||
)
|
||||
|
||||
if gpu_output is not None:
|
||||
moe_output = moe_output + gpu_output
|
||||
|
||||
return moe_output
|
||||
|
||||
def _sync_forward_output_no_autograd(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
all_qlens: list[int] | tuple[int, ...] | None,
|
||||
) -> torch.Tensor:
|
||||
"""Sync CPU expert output without creating KTMoEFunction autograd nodes."""
|
||||
import torch.distributed as dist
|
||||
|
||||
original_device = hidden_states.device
|
||||
original_dtype = hidden_states.dtype
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
qlen = batch_size * seq_len
|
||||
|
||||
dist_on = dist.is_initialized() and dist.get_world_size() > 1
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
world_size = dist.get_world_size() if dist_on else 1
|
||||
|
||||
if dist_on:
|
||||
if all_qlens is None:
|
||||
all_qlens_list = _all_gather_qlens(qlen, original_device, world_size)
|
||||
else:
|
||||
all_qlens_list = [int(q) for q in all_qlens]
|
||||
if len(all_qlens_list) != world_size:
|
||||
raise RuntimeError(
|
||||
f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}"
|
||||
)
|
||||
if int(all_qlens_list[rank]) != qlen:
|
||||
raise RuntimeError(
|
||||
f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}"
|
||||
)
|
||||
total_qlen = sum(all_qlens_list)
|
||||
|
||||
if rank == 0:
|
||||
if self.wrapper is None:
|
||||
raise RuntimeError("Rank0 wrapper is required in distributed KT overlap path.")
|
||||
cpu_output = self.wrapper.sync_forward(output_device=original_device)
|
||||
cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, self.hidden_size)
|
||||
offsets = _qlen_offsets(all_qlens_list)
|
||||
scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
|
||||
else:
|
||||
scatter_list = None
|
||||
|
||||
output_flat = _dist_scatter_varlen_from_rank0(
|
||||
rank0_chunks=scatter_list,
|
||||
all_qlens=all_qlens_list,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
feature_shape=(self.hidden_size,),
|
||||
device=original_device,
|
||||
dtype=original_dtype,
|
||||
)
|
||||
output = output_flat.view(batch_size, seq_len, self.hidden_size)
|
||||
del output_flat
|
||||
return output
|
||||
|
||||
if self.wrapper is not None:
|
||||
cpu_output = self.wrapper.sync_forward(output_device=original_device)
|
||||
output = cpu_output.view(batch_size, seq_len, self.hidden_size).to(dtype=original_dtype)
|
||||
return output
|
||||
|
||||
return torch.empty(batch_size, seq_len, self.hidden_size, device=original_device, dtype=original_dtype)
|
||||
|
||||
def _compute_routing(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Run routing under no_grad to avoid creating autograd nodes whose
|
||||
# SavedVariables become orphan holders inside gradient checkpoint.
|
||||
# The gate is frozen during LoRA fine-tuning and the main gradient
|
||||
# flows through KTMoEFunction.backward()'s grad_input, so the
|
||||
# routing gradient contribution to hidden_states can be safely dropped.
|
||||
with torch.no_grad():
|
||||
router = getattr(self, self._router_attr)
|
||||
if self.router_type == "deepseek_gate":
|
||||
# DeepSeek V3's MoEGate has `assert not self.training` in its noaux_tc
|
||||
# routing path because the HF model is an inference-only port.
|
||||
# For LoRA fine-tuning the router is frozen, so eval() is safe.
|
||||
was_training = router.training
|
||||
if was_training:
|
||||
router.eval()
|
||||
router_output = router(hidden_states)
|
||||
if was_training:
|
||||
router.train()
|
||||
if len(router_output) == 2:
|
||||
topk_ids, topk_weights = router_output
|
||||
else:
|
||||
topk_ids, topk_weights = router_output[0], router_output[1]
|
||||
if topk_weights.is_floating_point():
|
||||
topk_weights = topk_weights.to(torch.bfloat16)
|
||||
return topk_ids, topk_weights
|
||||
|
||||
router_output = router(hidden_states.view(-1, self.hidden_size))
|
||||
# transformers v5 TopKRouter returns (router_logits, router_scores, router_indices)
|
||||
# directly — scores/indices are already topk-normalized.
|
||||
if isinstance(router_output, tuple):
|
||||
if len(router_output) >= 3:
|
||||
_logits, topk_weights, topk_ids = router_output[0], router_output[1], router_output[2]
|
||||
if topk_weights.is_floating_point():
|
||||
topk_weights = topk_weights.to(torch.bfloat16)
|
||||
return topk_ids, topk_weights
|
||||
router_output = router_output[0]
|
||||
|
||||
router_logits = router_output
|
||||
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(routing_weights, self.moe_config.num_experts_per_tok, dim=-1)
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
topk_weights = topk_weights.to(torch.bfloat16)
|
||||
return topk_ids, topk_weights
|
||||
|
||||
def _submit_and_compute_gpu(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
save_for_backward: bool,
|
||||
) -> tuple[torch.Tensor | None, list[int] | None]:
|
||||
import torch.distributed as dist
|
||||
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
original_device = hidden_states.device
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
dist_on = dist.is_initialized() and dist.get_world_size() > 1
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
world_size = dist.get_world_size() if dist_on else 1
|
||||
|
||||
qlen = batch_size * seq_len
|
||||
|
||||
if dist_on:
|
||||
all_qlens = _all_gather_qlens(qlen, original_device, world_size)
|
||||
if int(all_qlens[rank]) != qlen:
|
||||
raise RuntimeError(
|
||||
f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens[rank]}"
|
||||
)
|
||||
total_qlen = sum(all_qlens)
|
||||
|
||||
hs_flat = hidden_states.view(qlen, self.hidden_size).contiguous()
|
||||
expert_ids = topk_ids.view(qlen, self.moe_config.num_experts_per_tok).contiguous()
|
||||
weights = topk_weights.view(qlen, self.moe_config.num_experts_per_tok).contiguous()
|
||||
|
||||
submit_hs = hs_flat.detach()
|
||||
submit_ids = expert_ids.detach()
|
||||
submit_wts = weights.detach()
|
||||
|
||||
gathered_hs = _dist_gather_varlen_to_rank0(
|
||||
submit_hs,
|
||||
all_qlens=all_qlens,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
gathered_ids = _dist_gather_varlen_to_rank0(
|
||||
submit_ids,
|
||||
all_qlens=all_qlens,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
gathered_wts = _dist_gather_varlen_to_rank0(
|
||||
submit_wts,
|
||||
all_qlens=all_qlens,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
all_hs = torch.cat(gathered_hs, dim=0)
|
||||
all_ids = torch.cat(gathered_ids, dim=0)
|
||||
all_wts = torch.cat(gathered_wts, dim=0)
|
||||
self.wrapper.submit_forward(
|
||||
all_hs,
|
||||
all_ids,
|
||||
all_wts,
|
||||
save_for_backward=save_for_backward,
|
||||
)
|
||||
|
||||
# Keep shared/lora experts local to avoid qlen_max-style amplification.
|
||||
gpu_output = None
|
||||
if self.shared_experts is not None:
|
||||
gpu_output = self.shared_experts(hidden_states)
|
||||
gpu_output = gpu_output.to(dtype=original_dtype)
|
||||
|
||||
if self.lora_experts is not None:
|
||||
lora_out = self.lora_experts(hidden_states)
|
||||
gpu_output = lora_out if gpu_output is None else gpu_output + lora_out
|
||||
|
||||
return gpu_output, all_qlens
|
||||
|
||||
else:
|
||||
# ---- Single-GPU path: submit + GPU compute ----
|
||||
input_flat = hidden_states.view(qlen, self.hidden_size)
|
||||
expert_ids = topk_ids.view(qlen, self.moe_config.num_experts_per_tok)
|
||||
weights = topk_weights.view(qlen, self.moe_config.num_experts_per_tok)
|
||||
|
||||
# Avoid passing graph-attached tensors into C++ cache.
|
||||
submit_hs = input_flat.detach()
|
||||
submit_ids = expert_ids.detach()
|
||||
submit_wts = weights.detach()
|
||||
self.wrapper.submit_forward(
|
||||
submit_hs,
|
||||
submit_ids,
|
||||
submit_wts,
|
||||
save_for_backward=save_for_backward,
|
||||
)
|
||||
|
||||
# GPU compute: shared_experts + lora_experts
|
||||
gpu_output = None
|
||||
if self.shared_experts is not None:
|
||||
gpu_output = self.shared_experts(hidden_states)
|
||||
if self.lora_experts is not None:
|
||||
lora_out = self.lora_experts(hidden_states)
|
||||
gpu_output = lora_out if gpu_output is None else gpu_output + lora_out
|
||||
|
||||
return gpu_output, None
|
||||
|
||||
def update_lora_pointers(self):
|
||||
"""Sync PEFT LoRA weights to C++ kernel after optimizer update."""
|
||||
# Skip if wrapper is None (non-rank-0 processes)
|
||||
if self.wrapper is None:
|
||||
return
|
||||
# Skip if wrapper is not properly initialized
|
||||
if not getattr(self.wrapper, "_weights_loaded", False):
|
||||
logger.warning(f"Layer {self.layer_idx}: Skipping update_lora_pointers - weights not loaded")
|
||||
return
|
||||
if not getattr(self.wrapper, "_lora_initialized", False):
|
||||
logger.warning(f"Layer {self.layer_idx}: Skipping update_lora_pointers - LoRA not initialized")
|
||||
return
|
||||
|
||||
# PEFT weights are views into wrapper's contiguous buffers —
|
||||
# optimizer.step() already updated them in-place, just re-sync to C++.
|
||||
self.wrapper.update_lora_weights()
|
||||
803
kt-kernel/python/sft/lora.py
Normal file
803
kt-kernel/python/sft/lora.py
Normal file
|
|
@ -0,0 +1,803 @@
|
|||
# PEFT LoRA adaptation utilities for SFT
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
PEFT LoRA integration for KT-Kernel MoE training.
|
||||
|
||||
Handles:
|
||||
- LoRA Expert modules (LoRAExpertMLP, LoRAExperts)
|
||||
- PEFT LoRA adaptation onto KT wrappers (contiguous buffer views, grad buffers)
|
||||
- LoRA parameter collection for optimizer injection
|
||||
- Checkpoint save/load for lora_experts
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .arch import MOEArchConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LoRA Experts Modules
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class LoRAExpertMLP(nn.Module):
|
||||
"""Single LoRA Expert with SwiGLU activation structure."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
self.le_gate = nn.Linear(hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.le_up = nn.Linear(hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.le_down = nn.Linear(intermediate_size, hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
nn.init.zeros_(self.le_down.weight)
|
||||
nn.init.kaiming_uniform_(self.le_gate.weight, a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.le_up.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.le_down(self.act_fn(self.le_gate(x)) * self.le_up(x))
|
||||
|
||||
|
||||
class LoRAExperts(nn.Module):
|
||||
"""LoRA Experts module containing multiple LoRA Expert MLPs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
self.experts = nn.ModuleList(
|
||||
[LoRAExpertMLP(hidden_size, intermediate_size, device, dtype) for _ in range(num_experts)]
|
||||
)
|
||||
self.num_experts = num_experts
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
output = torch.zeros_like(hidden_states)
|
||||
for expert in self.experts:
|
||||
output = output + expert(hidden_states)
|
||||
return output / self.num_experts
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LoRA Parameter Collection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _find_kt_wrappers(model: nn.Module):
|
||||
"""Find _kt_wrappers on model, unwrapping PEFT/other wrappers if needed."""
|
||||
wrappers = getattr(model, "_kt_wrappers", None)
|
||||
if wrappers is None:
|
||||
base_model = model
|
||||
for attr in ("base_model", "model"):
|
||||
if hasattr(base_model, attr):
|
||||
base_model = getattr(base_model, attr)
|
||||
wrappers = getattr(base_model, "_kt_wrappers", None)
|
||||
if wrappers:
|
||||
break
|
||||
return wrappers
|
||||
|
||||
|
||||
def get_kt_lora_params(model: nn.Module) -> list[nn.Parameter]:
|
||||
"""Get all MoE LoRA parameters from KT model.
|
||||
|
||||
Returns PEFT LoRA parameters from expert modules and lora_experts parameters.
|
||||
"""
|
||||
params: list[nn.Parameter] = []
|
||||
|
||||
wrappers = _find_kt_wrappers(model)
|
||||
|
||||
if wrappers:
|
||||
for wrapper in wrappers:
|
||||
# PEFT LoRA parameters (from _peft_lora_modules)
|
||||
peft_lora_modules = getattr(wrapper, "_peft_lora_modules", None)
|
||||
if peft_lora_modules is not None:
|
||||
for expert_loras in peft_lora_modules.values():
|
||||
for lora_A, lora_B in expert_loras.values():
|
||||
if hasattr(lora_A, 'weight') and lora_A.weight.requires_grad:
|
||||
params.append(lora_A.weight)
|
||||
if hasattr(lora_B, 'weight') and lora_B.weight.requires_grad:
|
||||
params.append(lora_B.weight)
|
||||
# Fused expert LoRA parameters (KT-managed, not PEFT)
|
||||
fused_params = getattr(wrapper, "_fused_expert_lora_params", None)
|
||||
if fused_params is not None:
|
||||
params.extend(fused_params)
|
||||
# lora_experts parameters (separate feature)
|
||||
if getattr(wrapper, "lora_experts", None) is not None:
|
||||
params.extend(wrapper.lora_experts.parameters())
|
||||
|
||||
return params
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PEFT LoRA Adaptation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def kt_adapt_peft_lora(model: nn.Module) -> None:
|
||||
"""
|
||||
Adapt PEFT LoRA on expert modules for KT kernel.
|
||||
|
||||
After PEFT injects LoRA adapters onto expert Linear modules, this function:
|
||||
1. Detects PEFT LoRA presence and rank on each wrapper's experts
|
||||
2. Stores references to PEFT LoRA modules on the wrapper (for backward gradient writing)
|
||||
3. Syncs initial PEFT LoRA weights to the C++ KT kernel (rank 0 only)
|
||||
|
||||
PEFT LoRA remains active and is managed by PEFT. No separate KT lora_params created.
|
||||
Optimizer updates PEFT LoRA directly, and KT kernel reads from PEFT LoRA on each forward.
|
||||
|
||||
Should be called after PEFT LoRA injection and before create_optimizer.
|
||||
"""
|
||||
import torch.distributed as dist
|
||||
|
||||
wrappers = _find_kt_wrappers(model)
|
||||
|
||||
if not wrappers:
|
||||
logger.info("[kt_adapt_peft_lora] No _kt_wrappers found, skipping")
|
||||
return
|
||||
|
||||
is_rank_0 = True
|
||||
if dist.is_initialized():
|
||||
is_rank_0 = dist.get_rank() == 0
|
||||
|
||||
adapted_count = 0
|
||||
for wrapper in wrappers:
|
||||
moe_config = wrapper.moe_config
|
||||
layer_idx = wrapper.layer_idx
|
||||
experts_attr = getattr(wrapper, "_experts_attr", "experts")
|
||||
experts = getattr(wrapper, experts_attr, None)
|
||||
|
||||
if experts is None:
|
||||
continue
|
||||
|
||||
# Fused experts (transformers v5): PEFT cannot auto-attach LoRA to packed
|
||||
# nn.Parameter tensors. Create KT-managed LoRA buffers with proper init,
|
||||
# wrap as nn.Parameter for optimizer, and pre-assign .grad for C++ backward.
|
||||
if getattr(wrapper, "_fused_experts", False):
|
||||
lora_rank = getattr(wrapper, "_lora_rank", 1)
|
||||
lora_buffers, lora_grad_buffers, lora_params = _create_fused_expert_lora_buffers(
|
||||
wrapper, moe_config, lora_rank, torch.bfloat16,
|
||||
)
|
||||
|
||||
if is_rank_0 and wrapper.wrapper is not None:
|
||||
all_buffers = {}
|
||||
all_buffers.update(lora_buffers)
|
||||
all_buffers.update(lora_grad_buffers)
|
||||
wrapper.wrapper.init_lora_weights(**all_buffers)
|
||||
logger.info(
|
||||
f"[kt_adapt_peft_lora] Layer {layer_idx}: fused expert LoRA "
|
||||
f"(r={lora_rank}, E={moe_config.expert_num})"
|
||||
)
|
||||
|
||||
wrapper._fused_expert_lora_params = lora_params
|
||||
wrapper._peft_lora_modules = None
|
||||
adapted_count += 1
|
||||
continue
|
||||
|
||||
if len(experts) == 0:
|
||||
continue
|
||||
|
||||
# Collect references to PEFT LoRA modules for each expert
|
||||
# Structure: {expert_idx: {proj_name: (lora_A_module, lora_B_module)}}
|
||||
peft_lora_modules = {}
|
||||
gate_name, up_name, down_name = moe_config.weight_names
|
||||
|
||||
for expert_idx, expert in enumerate(experts):
|
||||
expert_loras = {}
|
||||
for proj_name in (gate_name, up_name, down_name):
|
||||
proj = getattr(expert, proj_name, None)
|
||||
if proj is None:
|
||||
continue
|
||||
lora_A = getattr(proj, "lora_A", None)
|
||||
lora_B = getattr(proj, "lora_B", None)
|
||||
if lora_A is not None and lora_B is not None:
|
||||
# Get the actual Linear modules (inside ModuleDict if using adapters)
|
||||
if isinstance(lora_A, nn.ModuleDict):
|
||||
adapter_name = "default"
|
||||
active = getattr(proj, "active_adapter", ["default"])
|
||||
if isinstance(active, (list, tuple)) and active:
|
||||
adapter_name = active[0]
|
||||
# ModuleDict doesn't have .get(), use [] with in check
|
||||
lora_A = lora_A[adapter_name] if adapter_name in lora_A else None
|
||||
lora_B = lora_B[adapter_name] if adapter_name in lora_B else None
|
||||
if lora_A is not None and lora_B is not None:
|
||||
expert_loras[proj_name] = (lora_A, lora_B)
|
||||
if expert_loras:
|
||||
peft_lora_modules[expert_idx] = expert_loras
|
||||
|
||||
# Store PEFT LoRA references on wrapper
|
||||
wrapper._peft_lora_modules = peft_lora_modules
|
||||
|
||||
if not peft_lora_modules:
|
||||
raise RuntimeError(
|
||||
f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found on any expert. "
|
||||
f"Check that PEFT lora_target includes expert modules."
|
||||
)
|
||||
|
||||
# Allocate contiguous bf16 buffers and populate with initial PEFT values (all ranks)
|
||||
lora_buffers = _create_lora_view_buffers(peft_lora_modules, moe_config, torch.bfloat16)
|
||||
lora_grad_buffers = _create_lora_grad_buffers(peft_lora_modules, moe_config)
|
||||
|
||||
# Rank 0: pass buffers to C++ wrapper (init_lora_weights stores them via .contiguous() no-op)
|
||||
if is_rank_0 and wrapper.wrapper is not None:
|
||||
# concat lora_buffers and lora_grad_buffers into single dict
|
||||
lora_buffers.update(lora_grad_buffers)
|
||||
wrapper.wrapper.init_lora_weights(**lora_buffers)
|
||||
logger.info(f"[kt_adapt_peft_lora] Layer {layer_idx}: synced PEFT LoRA to C++ kernel")
|
||||
|
||||
# All ranks: replace PEFT weights with views into the contiguous buffers
|
||||
_replace_peft_weights_with_views(peft_lora_modules, lora_buffers, lora_grad_buffers, moe_config)
|
||||
|
||||
adapted_count += 1
|
||||
|
||||
# After collecting all LoRA references, shrink expert base weight parameters
|
||||
# from their original shape (e.g. [768, 2048]) to scalar (1,).
|
||||
# These base weights were already replaced with tiny-storage stride=[0] placeholders
|
||||
# by _clear_original_expert_weights(). They have correct shape but serve no purpose
|
||||
# after PEFT injection. FSDP2 broadcasts ALL non-DTensor params, and uses
|
||||
# torch.empty(param.size()) on non-rank-0 — with the original shape this wastes
|
||||
# ~28GB+. Shrinking to (1,) reduces broadcast cost to ~30KB total.
|
||||
shrunk_count = 0
|
||||
shrunk_saved_bytes = 0
|
||||
for wrapper in wrappers:
|
||||
experts_attr = getattr(wrapper, "_experts_attr", "experts")
|
||||
experts = getattr(wrapper, experts_attr, None)
|
||||
if experts is None:
|
||||
continue
|
||||
if getattr(wrapper, "_fused_experts", False):
|
||||
continue
|
||||
for expert in experts:
|
||||
for param_name, param in list(expert.named_parameters()):
|
||||
if param.requires_grad:
|
||||
continue # Skip trainable params (LoRA weights)
|
||||
try:
|
||||
storage_bytes = param.data.untyped_storage().nbytes()
|
||||
except Exception:
|
||||
continue
|
||||
if storage_bytes > 2:
|
||||
continue # Skip non-placeholder params
|
||||
|
||||
# This is a tiny-storage placeholder (base weight) — replace with
|
||||
# a scalar (1,) parameter so FSDP broadcasts only 1 element.
|
||||
original_numel = param.nelement()
|
||||
parts = param_name.split(".")
|
||||
container = expert
|
||||
for p in parts[:-1]:
|
||||
container = getattr(container, p)
|
||||
local_name = parts[-1]
|
||||
container_params = getattr(container, "_parameters", {})
|
||||
if isinstance(container_params, dict) and local_name in container_params:
|
||||
scalar_param = nn.Parameter(
|
||||
torch.empty(1, dtype=param.dtype, device="cpu"),
|
||||
requires_grad=False,
|
||||
)
|
||||
container_params[local_name] = scalar_param
|
||||
shrunk_count += 1
|
||||
shrunk_saved_bytes += (original_numel - 1) * param.element_size()
|
||||
|
||||
if shrunk_count > 0:
|
||||
logger.info(
|
||||
f"[kt_adapt_peft_lora] Shrunk {shrunk_count} expert base weight params "
|
||||
f"to shape (1,), FSDP broadcast savings={shrunk_saved_bytes / 1024 / 1024:.1f} MB"
|
||||
)
|
||||
|
||||
logger.info(f"[kt_adapt_peft_lora] Adapted {adapted_count} layers (PEFT LoRA mode)")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Contiguous Buffer Creation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _create_lora_view_buffers(
|
||||
peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]],
|
||||
moe_config: MOEArchConfig,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Allocate contiguous buffers and populate with initial PEFT LoRA values.
|
||||
|
||||
Returns dict with gate_lora_a, gate_lora_b, up_lora_a, up_lora_b,
|
||||
down_lora_a, down_lora_b — each shape [num_experts, ...].
|
||||
"""
|
||||
gate_name, up_name, down_name = moe_config.weight_names
|
||||
num_experts = moe_config.expert_num
|
||||
|
||||
first_expert_loras = peft_lora_modules.get(0, {})
|
||||
if not first_expert_loras:
|
||||
raise RuntimeError("No PEFT LoRA found on expert 0")
|
||||
gate_lora = first_expert_loras.get(gate_name)
|
||||
if gate_lora is None:
|
||||
raise RuntimeError(f"No PEFT LoRA found on expert 0 {gate_name}")
|
||||
|
||||
lora_rank = gate_lora[0].weight.shape[0]
|
||||
hidden_size = gate_lora[0].weight.shape[1]
|
||||
intermediate_size = gate_lora[1].weight.shape[0]
|
||||
|
||||
buffers = {
|
||||
"gate_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"),
|
||||
"gate_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"),
|
||||
"up_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"),
|
||||
"up_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"),
|
||||
"down_lora_a": torch.zeros(num_experts, lora_rank, intermediate_size, dtype=dtype, device="cpu"),
|
||||
"down_lora_b": torch.zeros(num_experts, hidden_size, lora_rank, dtype=dtype, device="cpu"),
|
||||
}
|
||||
|
||||
proj_to_keys = {
|
||||
gate_name: ("gate_lora_a", "gate_lora_b"),
|
||||
up_name: ("up_lora_a", "up_lora_b"),
|
||||
down_name: ("down_lora_a", "down_lora_b"),
|
||||
}
|
||||
for expert_idx in range(num_experts):
|
||||
expert_loras = peft_lora_modules.get(expert_idx, {})
|
||||
for proj_name, (key_a, key_b) in proj_to_keys.items():
|
||||
if proj_name in expert_loras:
|
||||
lora_A, lora_B = expert_loras[proj_name]
|
||||
buffers[key_a][expert_idx].copy_(lora_A.weight.data.to(dtype=dtype))
|
||||
buffers[key_b][expert_idx].copy_(lora_B.weight.data.to(dtype=dtype))
|
||||
|
||||
return buffers
|
||||
|
||||
|
||||
def _create_lora_grad_buffers(
|
||||
peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]],
|
||||
moe_config: MOEArchConfig,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Allocate contiguous gradient buffers for PEFT LoRA.
|
||||
|
||||
Returns dict with grad_gate_lora_a, grad_gate_lora_b, etc. — each shape [num_experts, ...].
|
||||
"""
|
||||
gate_name, up_name, down_name = moe_config.weight_names
|
||||
num_experts = moe_config.expert_num
|
||||
|
||||
first_expert_loras = peft_lora_modules.get(0, {})
|
||||
if not first_expert_loras:
|
||||
raise RuntimeError("No PEFT LoRA found on expert 0")
|
||||
gate_lora = first_expert_loras.get(gate_name)
|
||||
if gate_lora is None:
|
||||
raise RuntimeError(f"No PEFT LoRA found on expert 0 {gate_name}")
|
||||
|
||||
lora_rank = gate_lora[0].weight.shape[0]
|
||||
hidden_size = gate_lora[0].weight.shape[1]
|
||||
intermediate_size = gate_lora[1].weight.shape[0]
|
||||
|
||||
buffers = {
|
||||
"grad_gate_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"),
|
||||
"grad_gate_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"),
|
||||
"grad_up_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"),
|
||||
"grad_up_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"),
|
||||
"grad_down_lora_a": torch.zeros(num_experts, lora_rank, intermediate_size, dtype=dtype, device="cpu"),
|
||||
"grad_down_lora_b": torch.zeros(num_experts, hidden_size, lora_rank, dtype=dtype, device="cpu"),
|
||||
}
|
||||
|
||||
return buffers
|
||||
|
||||
|
||||
def _create_fused_expert_lora_buffers(
|
||||
wrapper,
|
||||
moe_config: MOEArchConfig,
|
||||
lora_rank: int,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], list[nn.Parameter]]:
|
||||
"""
|
||||
Create KT-managed LoRA buffers for fused expert modules.
|
||||
|
||||
Fused experts store weights as 3D parameters (gate_up_proj [E, 2I, H], down_proj [E, H, I])
|
||||
rather than per-expert nn.Linear modules. PEFT can't attach per-expert LoRA to these,
|
||||
so we create our own LoRA buffers that the C++ kernel reads/writes directly.
|
||||
|
||||
Returns:
|
||||
(lora_buffers, lora_grad_buffers, lora_params):
|
||||
- lora_buffers: dict of weight buffers for C++ init_lora_weights()
|
||||
- lora_grad_buffers: dict of grad buffers for C++ backward
|
||||
- lora_params: list of nn.Parameter wrappers for the optimizer
|
||||
"""
|
||||
E = moe_config.expert_num
|
||||
I = moe_config.intermediate_size
|
||||
H = wrapper.hidden_size
|
||||
r = lora_rank
|
||||
|
||||
logger.info(f"[_create_fused_expert_lora_buffers] E={E}, I={I}, H={H}, r={r}")
|
||||
|
||||
lora_buffers = {
|
||||
"gate_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
|
||||
"gate_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
|
||||
"up_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
|
||||
"up_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
|
||||
"down_lora_a": torch.zeros(E, r, I, dtype=dtype, device="cpu"),
|
||||
"down_lora_b": torch.zeros(E, H, r, dtype=dtype, device="cpu"),
|
||||
}
|
||||
|
||||
for key in ("gate_lora_a", "up_lora_a", "down_lora_a"):
|
||||
nn.init.kaiming_uniform_(lora_buffers[key].view(E * r, -1), a=math.sqrt(5))
|
||||
|
||||
lora_grad_buffers = {
|
||||
"grad_gate_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
|
||||
"grad_gate_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
|
||||
"grad_up_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"),
|
||||
"grad_up_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"),
|
||||
"grad_down_lora_a": torch.zeros(E, r, I, dtype=dtype, device="cpu"),
|
||||
"grad_down_lora_b": torch.zeros(E, H, r, dtype=dtype, device="cpu"),
|
||||
}
|
||||
|
||||
lora_params = []
|
||||
for key in ("gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"):
|
||||
param = nn.Parameter(lora_buffers[key], requires_grad=True)
|
||||
param.grad = lora_grad_buffers[f"grad_{key}"]
|
||||
lora_params.append(param)
|
||||
|
||||
return lora_buffers, lora_grad_buffers, lora_params
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PEFT Weight View Replacement
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _replace_peft_weights_with_views(
|
||||
peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]],
|
||||
buffers: dict[str, torch.Tensor],
|
||||
grad_buffers: dict[str, torch.Tensor],
|
||||
moe_config: MOEArchConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Replace each PEFT LoRA module's .weight with a view into the contiguous buffer.
|
||||
|
||||
After this, optimizer.step() updates the buffer in-place via the view —
|
||||
no copy needed to sync with C++.
|
||||
"""
|
||||
gate_name, up_name, down_name = moe_config.weight_names
|
||||
num_experts = moe_config.expert_num
|
||||
|
||||
proj_to_keys = {
|
||||
gate_name: ("gate_lora_a", "gate_lora_b"),
|
||||
up_name: ("up_lora_a", "up_lora_b"),
|
||||
down_name: ("down_lora_a", "down_lora_b"),
|
||||
}
|
||||
|
||||
_replaced = 0
|
||||
_first_logged = False
|
||||
for expert_idx in range(num_experts):
|
||||
expert_loras = peft_lora_modules.get(expert_idx, {})
|
||||
for proj_name, (key_a, key_b) in proj_to_keys.items():
|
||||
if proj_name not in expert_loras:
|
||||
continue
|
||||
lora_A, lora_B = expert_loras[proj_name]
|
||||
|
||||
# Log before/after for first replacement to verify .data assignment
|
||||
if not _first_logged:
|
||||
_old_id_a = id(lora_A.weight)
|
||||
_old_ptr_a = lora_A.weight.data_ptr()
|
||||
|
||||
# Use .data assignment to keep the same Parameter objects.
|
||||
# This preserves optimizer references (which point to these objects).
|
||||
# Creating new nn.Parameter() would break the optimizer link.
|
||||
lora_A.weight.data = buffers[key_a][expert_idx]
|
||||
lora_B.weight.data = buffers[key_b][expert_idx]
|
||||
lora_A.weight.requires_grad_(True)
|
||||
lora_B.weight.requires_grad_(True)
|
||||
lora_A.weight.grad = grad_buffers["grad_" + key_a][expert_idx]
|
||||
lora_B.weight.grad = grad_buffers["grad_" + key_b][expert_idx]
|
||||
|
||||
if not _first_logged:
|
||||
_new_id_a = id(lora_A.weight)
|
||||
_new_ptr_a = lora_A.weight.data_ptr()
|
||||
_buf_ptr_a = buffers[key_a][expert_idx].data_ptr()
|
||||
_has_grad = lora_A.weight.grad is not None
|
||||
logger.info(
|
||||
"[_replace_peft_weights_with_views] first param: "
|
||||
"id %s->%s (same=%s) data_ptr %s->%s buf_ptr=%s (match=%s) "
|
||||
"has_grad=%s requires_grad=%s shape=%s",
|
||||
_old_id_a, _new_id_a, _old_id_a == _new_id_a,
|
||||
_old_ptr_a, _new_ptr_a, _buf_ptr_a, _new_ptr_a == _buf_ptr_a,
|
||||
_has_grad, lora_A.weight.requires_grad, tuple(lora_A.weight.shape),
|
||||
)
|
||||
_first_logged = True
|
||||
_replaced += 1
|
||||
|
||||
logger.info("[_replace_peft_weights_with_views] replaced %d param pairs", _replaced)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Runtime LoRA Pointer Updates
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def update_kt_lora_pointers(model: nn.Module):
|
||||
"""Mark KT wrapper LoRA pointers as dirty after optimizer.step()."""
|
||||
wrappers = _find_kt_wrappers(model)
|
||||
|
||||
if wrappers:
|
||||
for wrapper in wrappers:
|
||||
wrapper._lora_pointers_dirty = True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cross-Rank Gradient Synchronization
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def sync_kt_lora_gradients(model: nn.Module) -> None:
|
||||
"""
|
||||
Synchronize KT-managed LoRA gradients across ranks.
|
||||
|
||||
KT computes expert LoRA gradients only on rank 0 (gather/scatter path). This function broadcasts the
|
||||
per-layer contiguous grad buffers from rank 0 to all ranks so that:
|
||||
- gradient clipping sees identical grads on every rank
|
||||
- optimizer.step() applies identical updates
|
||||
"""
|
||||
import torch.distributed as dist
|
||||
|
||||
if not (dist.is_initialized() and dist.get_world_size() > 1):
|
||||
return
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
if world_size <= 1:
|
||||
return
|
||||
|
||||
params = get_kt_lora_params(model)
|
||||
if not params:
|
||||
return
|
||||
|
||||
for param in params:
|
||||
if param.grad is not None:
|
||||
# Move grad to the same device as the parameter for all-reduce
|
||||
# Then move back to CPU
|
||||
original_device = param.grad.device
|
||||
if original_device.type == "cpu":
|
||||
# All-reduce on CPU might be slow; consider using a GPU buffer
|
||||
grad_gpu = param.grad.cuda()
|
||||
dist.all_reduce(grad_gpu, op=dist.ReduceOp.SUM)
|
||||
grad_gpu.div_(world_size)
|
||||
param.grad.copy_(grad_gpu.cpu())
|
||||
else:
|
||||
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
|
||||
param.grad.div_(world_size)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Checkpoint Save/Load
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def save_lora_experts_to_adapter(model: nn.Module, output_dir: str) -> None:
|
||||
"""
|
||||
Save LoRA Experts weights to adapter file by merging with existing Attention LoRA.
|
||||
"""
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
wrappers = _find_kt_wrappers(model) or []
|
||||
if not wrappers:
|
||||
logger.warning("No KT wrappers found, skipping LoRA Experts saving")
|
||||
return
|
||||
|
||||
adapter_file = os.path.join(output_dir, "adapter_model.safetensors")
|
||||
if not os.path.exists(adapter_file):
|
||||
adapter_file_bin = os.path.join(output_dir, "adapter_model.bin")
|
||||
if os.path.exists(adapter_file_bin):
|
||||
state_dict = torch.load(adapter_file_bin, map_location="cpu", weights_only=True)
|
||||
else:
|
||||
logger.warning(f"No existing adapter file found at {output_dir}, creating new one")
|
||||
state_dict = {}
|
||||
else:
|
||||
state_dict = {}
|
||||
with safe_open(adapter_file, framework="pt") as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key)
|
||||
|
||||
lora_expert_count = 0
|
||||
for wrapper in wrappers:
|
||||
if wrapper.lora_experts is None:
|
||||
continue
|
||||
|
||||
layer_idx = wrapper.layer_idx
|
||||
for expert_idx, expert in enumerate(wrapper.lora_experts.experts):
|
||||
base_key = f"base_model.model.model.layers.{layer_idx}.mlp.lora_experts.{expert_idx}"
|
||||
state_dict[f"{base_key}.le_gate.weight"] = expert.le_gate.weight.data.cpu().clone()
|
||||
state_dict[f"{base_key}.le_up.weight"] = expert.le_up.weight.data.cpu().clone()
|
||||
state_dict[f"{base_key}.le_down.weight"] = expert.le_down.weight.data.cpu().clone()
|
||||
lora_expert_count += 3
|
||||
|
||||
logger.debug(f"Added LoRA Experts for layer {layer_idx} ({len(wrapper.lora_experts.experts)} experts)")
|
||||
|
||||
output_file = os.path.join(output_dir, "adapter_model.safetensors")
|
||||
save_file(state_dict, output_file, metadata={"format": "pt"})
|
||||
|
||||
logger.info(
|
||||
f"Saved LoRA Experts to {output_file}: "
|
||||
f"{len(wrappers)} layers, {lora_expert_count} LoRA Expert tensors added, "
|
||||
f"{len(state_dict)} total tensors"
|
||||
)
|
||||
|
||||
|
||||
def save_kt_moe_to_adapter(model: nn.Module, output_dir: str) -> None:
|
||||
"""
|
||||
Unified function to save KT MoE weights to adapter file.
|
||||
Note: Per-expert PEFT LoRA is saved by PEFT directly, not here.
|
||||
This function only handles lora_experts (a separate feature).
|
||||
"""
|
||||
wrappers = _find_kt_wrappers(model) or []
|
||||
if not wrappers:
|
||||
logger.info("[save_kt_moe] No KT wrappers found, skipping")
|
||||
return
|
||||
|
||||
has_lora_experts = any(w.lora_experts is not None for w in wrappers)
|
||||
has_fused_lora = any(getattr(w, "_fused_expert_lora_params", None) is not None for w in wrappers)
|
||||
|
||||
if has_lora_experts:
|
||||
save_lora_experts_to_adapter(model, output_dir)
|
||||
|
||||
if has_fused_lora:
|
||||
_save_fused_expert_lora(wrappers, output_dir)
|
||||
|
||||
if not has_lora_experts and not has_fused_lora:
|
||||
logger.info("[save_kt_moe] No lora_experts or fused expert LoRA in KT wrappers")
|
||||
|
||||
|
||||
def _save_fused_expert_lora(wrappers: list, output_dir: str) -> None:
|
||||
"""Save fused expert LoRA params to a safetensors file."""
|
||||
from safetensors.torch import save_file
|
||||
|
||||
names = ["gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"]
|
||||
tensors = {}
|
||||
for w in wrappers:
|
||||
fused = getattr(w, "_fused_expert_lora_params", None)
|
||||
if fused is None:
|
||||
continue
|
||||
for param, name in zip(fused, names):
|
||||
key = f"layers.{w.layer_idx}.experts.{name}"
|
||||
tensors[key] = param.data.clone()
|
||||
|
||||
if tensors:
|
||||
path = os.path.join(output_dir, "fused_expert_lora.safetensors")
|
||||
save_file(tensors, path)
|
||||
logger.info(f"[save_kt_moe] Saved {len(tensors)} fused expert LoRA tensors to {path}")
|
||||
|
||||
|
||||
def _load_fused_expert_lora(wrappers: list, adapter_path: str) -> None:
|
||||
"""Load fused expert LoRA params from a safetensors file into existing wrapper buffers."""
|
||||
path = os.path.join(adapter_path, "fused_expert_lora.safetensors")
|
||||
if not os.path.isfile(path):
|
||||
logger.warning(f"No fused_expert_lora.safetensors found at {adapter_path}")
|
||||
return
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
saved = load_file(path)
|
||||
names = ["gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"]
|
||||
wrapper_map = {w.layer_idx: w for w in wrappers}
|
||||
loaded_count = 0
|
||||
|
||||
for key, tensor in saved.items():
|
||||
parts = key.split(".")
|
||||
if len(parts) != 4 or parts[0] != "layers" or parts[2] != "experts":
|
||||
logger.warning(f"Unexpected key in fused_expert_lora.safetensors: {key}")
|
||||
continue
|
||||
layer_idx = int(parts[1])
|
||||
name = parts[3]
|
||||
if name not in names:
|
||||
continue
|
||||
|
||||
wrapper = wrapper_map.get(layer_idx)
|
||||
if wrapper is None:
|
||||
continue
|
||||
fused = getattr(wrapper, "_fused_expert_lora_params", None)
|
||||
if fused is None:
|
||||
continue
|
||||
|
||||
param_idx = names.index(name)
|
||||
fused[param_idx].data.copy_(tensor)
|
||||
loaded_count += 1
|
||||
|
||||
logger.info(f"[_load_fused_expert_lora] Loaded {loaded_count} tensors from {path}")
|
||||
|
||||
|
||||
def load_lora_experts_from_adapter(model: nn.Module, adapter_path: str) -> None:
|
||||
"""
|
||||
Load LoRA Experts weights from adapter file into KT wrappers.
|
||||
"""
|
||||
from safetensors import safe_open
|
||||
|
||||
wrappers = _find_kt_wrappers(model) or []
|
||||
if not wrappers:
|
||||
logger.warning("No KT wrappers found, skipping LoRA Experts loading")
|
||||
return
|
||||
|
||||
wrapper_map = {w.layer_idx: w for w in wrappers if w.lora_experts is not None}
|
||||
if not wrapper_map:
|
||||
logger.warning("No LoRA Experts found in KT wrappers, skipping")
|
||||
return
|
||||
|
||||
# Prefer dedicated lora_experts file, fallback to adapter file
|
||||
adapter_file = os.path.join(adapter_path, "lora_experts.safetensors")
|
||||
if not os.path.exists(adapter_file):
|
||||
adapter_file = os.path.join(adapter_path, "adapter_model.safetensors")
|
||||
if not os.path.exists(adapter_file):
|
||||
adapter_file = os.path.join(adapter_path, "adapter_model.bin")
|
||||
if not os.path.exists(adapter_file):
|
||||
logger.warning(f"No lora_experts or adapter file found at {adapter_path}")
|
||||
return
|
||||
|
||||
logger.info(f"Loading LoRA Experts from {adapter_file}")
|
||||
|
||||
lora_expert_pattern = re.compile(
|
||||
r"base_model\.model\.model\.layers\.(\d+)\.mlp\.lora_experts\.(\d+)\.(le_gate|le_up|le_down)\.weight"
|
||||
)
|
||||
|
||||
layer_weights = {}
|
||||
with safe_open(adapter_file, framework="pt") as f:
|
||||
for key in f.keys():
|
||||
match = lora_expert_pattern.match(key)
|
||||
if match:
|
||||
layer_idx = int(match.group(1))
|
||||
expert_idx = int(match.group(2))
|
||||
proj_name = match.group(3)
|
||||
layer_weights.setdefault(layer_idx, {}).setdefault(expert_idx, {})[proj_name] = f.get_tensor(key)
|
||||
|
||||
loaded_count = 0
|
||||
for layer_idx, experts_dict in layer_weights.items():
|
||||
if layer_idx not in wrapper_map:
|
||||
logger.warning(f"No LoRA Experts for layer {layer_idx}, skipping")
|
||||
continue
|
||||
|
||||
wrapper = wrapper_map[layer_idx]
|
||||
for expert_idx, proj_dict in experts_dict.items():
|
||||
if expert_idx >= len(wrapper.lora_experts.experts):
|
||||
continue
|
||||
expert = wrapper.lora_experts.experts[expert_idx]
|
||||
if "le_gate" in proj_dict:
|
||||
expert.le_gate.weight.data.copy_(proj_dict["le_gate"].to(expert.le_gate.weight.device))
|
||||
if "le_up" in proj_dict:
|
||||
expert.le_up.weight.data.copy_(proj_dict["le_up"].to(expert.le_up.weight.device))
|
||||
if "le_down" in proj_dict:
|
||||
expert.le_down.weight.data.copy_(proj_dict["le_down"].to(expert.le_down.weight.device))
|
||||
loaded_count += 1
|
||||
|
||||
logger.info(f"Loaded LoRA Experts for {loaded_count} experts from {adapter_path}")
|
||||
|
||||
|
||||
def load_kt_moe_from_adapter(model: nn.Module, adapter_path: str) -> None:
|
||||
"""
|
||||
Unified function to load KT MoE weights from adapter file.
|
||||
Note: Per-expert PEFT LoRA is loaded by PEFT directly, not here.
|
||||
This function only handles lora_experts (a separate feature).
|
||||
"""
|
||||
wrappers = _find_kt_wrappers(model) or []
|
||||
if not wrappers:
|
||||
logger.warning("No KT wrappers found, skipping KT MoE loading")
|
||||
return
|
||||
|
||||
has_lora_experts = any(w.lora_experts is not None for w in wrappers)
|
||||
has_fused_lora = any(getattr(w, "_fused_expert_lora_params", None) is not None for w in wrappers)
|
||||
|
||||
if has_lora_experts:
|
||||
load_lora_experts_from_adapter(model, adapter_path)
|
||||
|
||||
if has_fused_lora:
|
||||
_load_fused_expert_lora(wrappers, adapter_path)
|
||||
|
||||
if not has_lora_experts and not has_fused_lora:
|
||||
logger.info("No lora_experts or fused expert LoRA in KT wrappers")
|
||||
557
kt-kernel/python/sft/weights.py
Normal file
557
kt-kernel/python/sft/weights.py
Normal file
|
|
@ -0,0 +1,557 @@
|
|||
# Weight extraction and loading utilities for SFT
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .arch import MOEArchConfig
|
||||
from .dist_utils import _maybe_zero3_gathered_parameters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
safe_open = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Weight Extraction
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def extract_moe_weights(
|
||||
moe_module: nn.Module, moe_config: MOEArchConfig
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Extract MoE expert weights from the module.
|
||||
|
||||
Returns (gate_proj, up_proj, down_proj) with shape
|
||||
[expert_num, out_features, in_features].
|
||||
|
||||
Supports two formats:
|
||||
- ModuleList of Linear experts (transformers v4 style)
|
||||
- Fused Parameters (transformers v5 style): single module with
|
||||
``gate_up_proj`` [E, 2*I, H] and ``down_proj`` [E, H, I] tensors.
|
||||
"""
|
||||
from .arch import detect_fused_experts
|
||||
|
||||
experts = getattr(moe_module, moe_config.experts_attr)
|
||||
|
||||
# Fused format (transformers v5): a single nn.Module with gate_up_proj/down_proj tensors
|
||||
if detect_fused_experts(experts):
|
||||
gate_up = getattr(experts, "gate_up_proj").data
|
||||
down_fused = getattr(experts, "down_proj").data
|
||||
# gate_up_proj is [E, 2*I, H], split into gate [E, I, H] and up [E, I, H]
|
||||
intermediate = gate_up.shape[1] // 2
|
||||
gate_proj = gate_up[:, :intermediate, :].contiguous()
|
||||
up_proj = gate_up[:, intermediate:, :].contiguous()
|
||||
# down_proj is already [E, H, I]
|
||||
down_proj = down_fused.contiguous()
|
||||
return gate_proj, up_proj, down_proj
|
||||
|
||||
gate_name, up_name, down_name = moe_config.weight_names
|
||||
|
||||
gather_params: list[torch.nn.Parameter] = []
|
||||
for expert in experts:
|
||||
for weight_name in (gate_name, up_name, down_name):
|
||||
proj = getattr(expert, weight_name, None)
|
||||
if proj is not None and hasattr(proj, "weight"):
|
||||
# Handle PEFT LoRA wrapped modules
|
||||
weight = proj.weight
|
||||
if isinstance(weight, torch.Tensor):
|
||||
gather_params.append(weight)
|
||||
elif hasattr(weight, "data"):
|
||||
gather_params.append(weight.data)
|
||||
|
||||
with _maybe_zero3_gathered_parameters(gather_params):
|
||||
gate_weights = []
|
||||
up_weights = []
|
||||
down_weights = []
|
||||
|
||||
for expert in experts:
|
||||
# Handle PEFT LoRA wrapped modules - get weight tensor properly
|
||||
gate_proj = getattr(expert, gate_name)
|
||||
up_proj_mod = getattr(expert, up_name)
|
||||
down_proj_mod = getattr(expert, down_name)
|
||||
|
||||
# Get weight tensors, handling both regular Linear and PEFT LoRA wrapped
|
||||
def get_weight_tensor(mod):
|
||||
weight = mod.weight
|
||||
if isinstance(weight, torch.Tensor):
|
||||
return weight.data
|
||||
elif hasattr(weight, "data"):
|
||||
return weight.data
|
||||
else:
|
||||
raise ValueError(f"Cannot extract weight from {type(mod)}, weight type={type(weight)}")
|
||||
|
||||
gate_weights.append(get_weight_tensor(gate_proj))
|
||||
up_weights.append(get_weight_tensor(up_proj_mod))
|
||||
down_weights.append(get_weight_tensor(down_proj_mod))
|
||||
|
||||
gate_proj = torch.stack(gate_weights, dim=0)
|
||||
up_proj = torch.stack(up_weights, dim=0)
|
||||
down_proj = torch.stack(down_weights, dim=0)
|
||||
|
||||
return gate_proj, up_proj, down_proj
|
||||
|
||||
|
||||
def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchConfig) -> None:
|
||||
"""
|
||||
Clear original expert weights to free memory after KT weights are loaded.
|
||||
"""
|
||||
from .arch import detect_fused_experts
|
||||
|
||||
experts = getattr(moe_module, moe_config.experts_attr, None)
|
||||
if experts is None:
|
||||
return
|
||||
|
||||
# Fused format: replace gate_up_proj/down_proj tensors with zero-storage placeholders
|
||||
if detect_fused_experts(experts):
|
||||
for name in ("gate_up_proj", "down_proj"):
|
||||
param = getattr(experts, name, None)
|
||||
if not isinstance(param, torch.nn.Parameter):
|
||||
continue
|
||||
original_dtype = param.dtype
|
||||
tiny_storage = torch.UntypedStorage(1, device="cpu")
|
||||
fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_(
|
||||
tiny_storage, storage_offset=0, size=param.shape,
|
||||
stride=[0] * len(param.shape),
|
||||
)
|
||||
experts._parameters[name] = nn.Parameter(fake_tensor, requires_grad=False)
|
||||
return
|
||||
|
||||
def _iter_weight_params():
|
||||
for expert in experts:
|
||||
for weight_name in moe_config.weight_names:
|
||||
proj = getattr(expert, weight_name, None)
|
||||
if proj is None or not hasattr(proj, "weight"):
|
||||
continue
|
||||
|
||||
parametrizations = getattr(proj, "parametrizations", None)
|
||||
parametrized_weight = getattr(parametrizations, "weight", None) if parametrizations is not None else None
|
||||
if parametrized_weight is not None:
|
||||
original = getattr(parametrized_weight, "original", None)
|
||||
if isinstance(original, torch.nn.Parameter):
|
||||
yield proj, parametrized_weight, "original", original
|
||||
continue
|
||||
|
||||
direct_weight = getattr(proj, "_parameters", {}).get("weight")
|
||||
if isinstance(direct_weight, torch.nn.Parameter):
|
||||
yield proj, proj, "weight", direct_weight
|
||||
continue
|
||||
|
||||
# Fallback: `weight` can be a non-settable property (e.g. parametrizations) or a non-Parameter.
|
||||
weight_attr = getattr(proj, "weight", None)
|
||||
if isinstance(weight_attr, torch.nn.Parameter):
|
||||
yield proj, proj, "weight", weight_attr
|
||||
|
||||
gather_params: list[torch.nn.Parameter] = []
|
||||
for _, _, _, weight_param in _iter_weight_params():
|
||||
gather_params.append(weight_param)
|
||||
|
||||
replaced_count = 0
|
||||
|
||||
with _maybe_zero3_gathered_parameters(gather_params):
|
||||
for proj, container, param_name, weight_param in _iter_weight_params():
|
||||
original_dtype = weight_param.dtype
|
||||
|
||||
# Create a CPU tensor with the correct shape but NO physical memory.
|
||||
# torch.empty(shape, device="cpu") unfortunately touches pages via the
|
||||
# allocator, consuming real RSS. Instead, allocate a 1-byte storage and
|
||||
# use set_ to give it the original shape with zero strides. The tensor
|
||||
# is "valid" (correct dtype, device, shape) so PEFT can discover
|
||||
# in/out features, but its storage is essentially zero-cost.
|
||||
# NOTE: reading element values from this tensor is undefined -- it is
|
||||
# only used for shape/dtype discovery by PEFT.
|
||||
tiny_storage = torch.UntypedStorage(1, device="cpu")
|
||||
fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_(
|
||||
tiny_storage, storage_offset=0, size=weight_param.shape,
|
||||
stride=[0] * len(weight_param.shape),
|
||||
)
|
||||
new_param = nn.Parameter(fake_tensor, requires_grad=False)
|
||||
replaced_count += 1
|
||||
|
||||
# Avoid `KeyError: attribute 'weight' already exists` for parametrized modules
|
||||
# where `weight` is a property and the real parameter lives elsewhere.
|
||||
container_params = getattr(container, "_parameters", {})
|
||||
if isinstance(container_params, dict) and param_name in container_params:
|
||||
container_params[param_name] = new_param
|
||||
continue
|
||||
|
||||
if hasattr(container, param_name):
|
||||
logger.debug(
|
||||
f"Skipping clearing expert weight {type(proj).__name__}.{param_name}: "
|
||||
"attribute exists but is not a registered parameter."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
setattr(container, param_name, new_param)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"Failed to clear expert weight {type(proj).__name__}.{param_name}: {exc}"
|
||||
)
|
||||
|
||||
logger.info(f"Replaced {replaced_count} expert weight params")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# kt_weight_path Loading Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class INT8ExpertWeights:
|
||||
"""Container for INT8 expert weights with scales."""
|
||||
|
||||
gate_proj: torch.Tensor
|
||||
gate_scale: torch.Tensor
|
||||
up_proj: torch.Tensor
|
||||
up_scale: torch.Tensor
|
||||
down_proj: torch.Tensor
|
||||
down_scale: torch.Tensor
|
||||
|
||||
|
||||
def _find_safetensor_files(kt_weight_path: str) -> list[str]:
|
||||
if not os.path.isdir(kt_weight_path):
|
||||
raise FileNotFoundError(f"kt_weight_path directory not found: {kt_weight_path}")
|
||||
|
||||
safetensor_files = []
|
||||
for file in sorted(os.listdir(kt_weight_path)):
|
||||
if file.endswith(".safetensors"):
|
||||
safetensor_files.append(os.path.join(kt_weight_path, file))
|
||||
|
||||
if not safetensor_files:
|
||||
raise FileNotFoundError(f"No safetensors files found in {kt_weight_path}")
|
||||
|
||||
return safetensor_files
|
||||
|
||||
|
||||
def _load_kt_weight_index(kt_weight_path: str) -> dict[str, str]:
|
||||
if not SAFETENSORS_AVAILABLE:
|
||||
raise ImportError("safetensors is required for loading kt_weight_path")
|
||||
|
||||
index = {}
|
||||
safetensor_files = _find_safetensor_files(kt_weight_path)
|
||||
|
||||
for file_path in safetensor_files:
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
for key in f.keys():
|
||||
index[key] = file_path
|
||||
|
||||
logger.info(f"Indexed {len(index)} tensors from {len(safetensor_files)} safetensors files")
|
||||
return index
|
||||
|
||||
|
||||
def _dequant_fp8_experts(weights: list[torch.Tensor], scales: list[torch.Tensor | None], block_size: tuple[int, int]) -> torch.Tensor:
|
||||
"""Dequantize a list of FP8 expert weights and stack them (batched, vectorized).
|
||||
|
||||
Args:
|
||||
weights: list of [out, in] float8_e4m3fn tensors (one per expert)
|
||||
scales: list of [out//bs_m, in//bs_n] scale_inv tensors (one per expert, may be None)
|
||||
block_size: (bs_m, bs_n)
|
||||
|
||||
Returns:
|
||||
Stacked BF16 tensor of shape [num_experts, out, in]
|
||||
"""
|
||||
has_scales = scales[0] is not None
|
||||
if not has_scales:
|
||||
return torch.stack(weights, dim=0).to(torch.bfloat16).cpu().contiguous()
|
||||
|
||||
bs_m, bs_n = block_size
|
||||
n = len(weights)
|
||||
out_features, in_features = weights[0].shape
|
||||
|
||||
# Stack all experts: [N, out, in] fp8 -> reshape to blocks -> bf16
|
||||
w = torch.stack(weights, dim=0) # [N, out, in] fp8
|
||||
w = w.reshape(n, out_features // bs_m, bs_m, in_features // bs_n, bs_n)
|
||||
w = w.to(torch.bfloat16)
|
||||
|
||||
# Stack all scales: [N, out//bs_m, in//bs_n] -> bf16, broadcast multiply
|
||||
s = torch.stack(scales, dim=0).to(torch.bfloat16) # [N, out//bs_m, in//bs_n]
|
||||
w = w * s[:, :, None, :, None]
|
||||
|
||||
return w.reshape(n, out_features, in_features).contiguous()
|
||||
|
||||
|
||||
def load_experts_from_checkpoint_files(
|
||||
checkpoint_files: list[str],
|
||||
sharded_metadata: dict | None,
|
||||
layers_prefix: str,
|
||||
moe_config: MOEArchConfig,
|
||||
layer_idx: int,
|
||||
block_size: tuple[int, int] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if not SAFETENSORS_AVAILABLE:
|
||||
raise ImportError("safetensors is required for loading experts from checkpoint files")
|
||||
|
||||
if not checkpoint_files:
|
||||
raise FileNotFoundError("checkpoint_files is empty")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
weight_map = None
|
||||
base_dir = os.path.dirname(checkpoint_files[0])
|
||||
if sharded_metadata is not None:
|
||||
weight_map = sharded_metadata.get("weight_map", None)
|
||||
|
||||
gate_name, up_name, down_name = moe_config.weight_names
|
||||
experts_prefix = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}"
|
||||
fused_gate_up_key = f"{experts_prefix}.gate_up_proj"
|
||||
fused_down_key = f"{experts_prefix}.down_proj"
|
||||
is_fused = weight_map is not None and fused_gate_up_key in weight_map
|
||||
|
||||
if is_fused:
|
||||
keys = [fused_gate_up_key, fused_down_key]
|
||||
else:
|
||||
keys = []
|
||||
for expert_idx in range(moe_config.expert_num):
|
||||
base = f"{experts_prefix}.{expert_idx}"
|
||||
keys.append(f"{base}.{gate_name}.weight")
|
||||
keys.append(f"{base}.{gate_name}.weight_scale_inv")
|
||||
keys.append(f"{base}.{up_name}.weight")
|
||||
keys.append(f"{base}.{up_name}.weight_scale_inv")
|
||||
keys.append(f"{base}.{down_name}.weight")
|
||||
keys.append(f"{base}.{down_name}.weight_scale_inv")
|
||||
|
||||
keys_by_file: dict[str, list[str]] = {}
|
||||
mapped_count = 0
|
||||
unmapped_count = 0
|
||||
for key in keys:
|
||||
if weight_map is not None:
|
||||
filename = weight_map.get(key)
|
||||
if filename is None:
|
||||
unmapped_count += 1
|
||||
continue
|
||||
mapped_count += 1
|
||||
file_path = os.path.join(base_dir, filename)
|
||||
else:
|
||||
file_path = checkpoint_files[0]
|
||||
keys_by_file.setdefault(file_path, []).append(key)
|
||||
|
||||
print(
|
||||
f"[kt_moe] Layer {layer_idx}: key mapping done in {time.time()-t0:.1f}s — "
|
||||
f"total_keys={len(keys)}, mapped={mapped_count}, unmapped={unmapped_count}, "
|
||||
f"files_to_open={len(keys_by_file)}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
t1 = time.time()
|
||||
tensor_map: dict[str, torch.Tensor] = {}
|
||||
for file_idx, (file_path, file_keys) in enumerate(keys_by_file.items()):
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
available_keys = set(f.keys())
|
||||
for key in file_keys:
|
||||
if key in available_keys:
|
||||
tensor_map[key] = f.get_tensor(key)
|
||||
if file_idx == 0:
|
||||
print(
|
||||
f"[kt_moe] Layer {layer_idx}: first file loaded ({os.path.basename(file_path)}, "
|
||||
f"{len(file_keys)} keys) in {time.time()-t1:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
print(
|
||||
f"[kt_moe] Layer {layer_idx}: all files loaded in {time.time()-t1:.1f}s — "
|
||||
f"tensor_map has {len(tensor_map)} tensors",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
t2 = time.time()
|
||||
if is_fused:
|
||||
gate_up = tensor_map.get(fused_gate_up_key)
|
||||
down = tensor_map.get(fused_down_key)
|
||||
if gate_up is None or down is None:
|
||||
raise FileNotFoundError(f"Missing fused expert weights for layer {layer_idx}")
|
||||
gate_up = gate_up.cpu().to(torch.bfloat16).contiguous()
|
||||
I = gate_up.shape[1] // 2
|
||||
gate_proj = gate_up[:, :I, :].contiguous()
|
||||
up_proj = gate_up[:, I:, :].contiguous()
|
||||
down_proj = down.cpu().to(torch.bfloat16).contiguous()
|
||||
del gate_up
|
||||
print(
|
||||
f"[kt_moe] Layer {layer_idx}: fused expert format — "
|
||||
f"split gate_up_proj [{gate_proj.shape}] + down [{down_proj.shape}]",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f"[kt_moe] Layer {layer_idx}: done — dtype={gate_proj.dtype}, "
|
||||
f"shape={gate_proj.shape}, dequant=0.0s, total={time.time()-t0:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
return gate_proj, up_proj, down_proj
|
||||
|
||||
gate_weights = []
|
||||
up_weights = []
|
||||
down_weights = []
|
||||
gate_scales = []
|
||||
up_scales = []
|
||||
down_scales = []
|
||||
for expert_idx in range(moe_config.expert_num):
|
||||
base = f"{experts_prefix}.{expert_idx}"
|
||||
gate_key = f"{base}.{gate_name}.weight"
|
||||
up_key = f"{base}.{up_name}.weight"
|
||||
down_key = f"{base}.{down_name}.weight"
|
||||
if gate_key not in tensor_map or up_key not in tensor_map or down_key not in tensor_map:
|
||||
raise FileNotFoundError(f"Missing expert weights for layer {layer_idx}, expert {expert_idx}")
|
||||
gate_weights.append(tensor_map[gate_key])
|
||||
up_weights.append(tensor_map[up_key])
|
||||
down_weights.append(tensor_map[down_key])
|
||||
gate_scales.append(tensor_map.get(f"{base}.{gate_name}.weight_scale_inv"))
|
||||
up_scales.append(tensor_map.get(f"{base}.{up_name}.weight_scale_inv"))
|
||||
down_scales.append(tensor_map.get(f"{base}.{down_name}.weight_scale_inv"))
|
||||
|
||||
# Check if weights are FP8 and need dequantization
|
||||
t2 = time.time()
|
||||
is_fp8 = gate_weights[0].dtype == torch.float8_e4m3fn
|
||||
if is_fp8:
|
||||
if block_size is None:
|
||||
block_size = (128, 128)
|
||||
print(
|
||||
f"[kt_moe] Layer {layer_idx}: FP8 expert weights detected, "
|
||||
f"dequantizing with block_size={block_size} "
|
||||
f"(has_scales={gate_scales[0] is not None})",
|
||||
flush=True,
|
||||
)
|
||||
gate_proj = _dequant_fp8_experts(gate_weights, gate_scales, block_size)
|
||||
up_proj = _dequant_fp8_experts(up_weights, up_scales, block_size)
|
||||
down_proj = _dequant_fp8_experts(down_weights, down_scales, block_size)
|
||||
else:
|
||||
gate_proj = torch.stack(gate_weights, dim=0).cpu().to(torch.bfloat16).contiguous()
|
||||
up_proj = torch.stack(up_weights, dim=0).cpu().to(torch.bfloat16).contiguous()
|
||||
down_proj = torch.stack(down_weights, dim=0).cpu().to(torch.bfloat16).contiguous()
|
||||
|
||||
print(
|
||||
f"[kt_moe] Layer {layer_idx}: done — dtype={gate_proj.dtype}, shape={gate_proj.shape}, "
|
||||
f"dequant={time.time()-t2:.1f}s, total={time.time()-t0:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
return gate_proj, up_proj, down_proj
|
||||
|
||||
|
||||
def load_experts_from_kt_weight_path(
|
||||
kt_weight_path: str,
|
||||
layer_idx: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
) -> INT8ExpertWeights:
|
||||
"""Load INT8 preprocessed expert weights from kt_weight_path for a specific layer."""
|
||||
if not SAFETENSORS_AVAILABLE:
|
||||
raise ImportError("safetensors is required for loading kt_weight_path")
|
||||
|
||||
index = _load_kt_weight_index(kt_weight_path)
|
||||
|
||||
numa_count = 0
|
||||
test_key_prefix = f"blk.{layer_idx}.ffn_gate_exps.0.numa."
|
||||
for key in index.keys():
|
||||
if key.startswith(test_key_prefix) and key.endswith(".weight"):
|
||||
numa_idx = int(key.split("numa.")[1].split(".")[0])
|
||||
numa_count = max(numa_count, numa_idx + 1)
|
||||
|
||||
if numa_count == 0:
|
||||
raise FileNotFoundError(
|
||||
f"No weights found for layer {layer_idx} in {kt_weight_path}. "
|
||||
f"Expected keys like 'blk.{layer_idx}.ffn_gate_exps.0.numa.0.weight'"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Loading INT8 weights for layer {layer_idx}: {num_experts} experts, {numa_count} NUMA partitions"
|
||||
)
|
||||
|
||||
gate_weights_list = []
|
||||
gate_scales_list = []
|
||||
up_weights_list = []
|
||||
up_scales_list = []
|
||||
down_weights_list = []
|
||||
down_scales_list = []
|
||||
|
||||
for expert_idx in range(num_experts):
|
||||
gate_w_parts = []
|
||||
gate_s_parts = []
|
||||
for numa_idx in range(numa_count):
|
||||
w_key = f"blk.{layer_idx}.ffn_gate_exps.{expert_idx}.numa.{numa_idx}.weight"
|
||||
s_key = f"blk.{layer_idx}.ffn_gate_exps.{expert_idx}.numa.{numa_idx}.scale"
|
||||
|
||||
if w_key not in index:
|
||||
raise FileNotFoundError(f"Weight key not found: {w_key}")
|
||||
|
||||
with safe_open(index[w_key], framework="pt") as f:
|
||||
gate_w_parts.append(f.get_tensor(w_key))
|
||||
gate_s_parts.append(f.get_tensor(s_key))
|
||||
|
||||
gate_w = torch.cat(gate_w_parts, dim=0)
|
||||
gate_s = torch.cat(gate_s_parts, dim=0)
|
||||
gate_w = gate_w.view(intermediate_size, hidden_size)
|
||||
|
||||
gate_weights_list.append(gate_w)
|
||||
gate_scales_list.append(gate_s)
|
||||
|
||||
up_w_parts = []
|
||||
up_s_parts = []
|
||||
for numa_idx in range(numa_count):
|
||||
w_key = f"blk.{layer_idx}.ffn_up_exps.{expert_idx}.numa.{numa_idx}.weight"
|
||||
s_key = f"blk.{layer_idx}.ffn_up_exps.{expert_idx}.numa.{numa_idx}.scale"
|
||||
|
||||
if w_key not in index:
|
||||
raise FileNotFoundError(f"Weight key not found: {w_key}")
|
||||
|
||||
with safe_open(index[w_key], framework="pt") as f:
|
||||
up_w_parts.append(f.get_tensor(w_key))
|
||||
up_s_parts.append(f.get_tensor(s_key))
|
||||
|
||||
up_w = torch.cat(up_w_parts, dim=0)
|
||||
up_s = torch.cat(up_s_parts, dim=0)
|
||||
up_w = up_w.view(intermediate_size, hidden_size)
|
||||
|
||||
up_weights_list.append(up_w)
|
||||
up_scales_list.append(up_s)
|
||||
|
||||
down_w_parts = []
|
||||
down_s_parts = []
|
||||
for numa_idx in range(numa_count):
|
||||
w_key = f"blk.{layer_idx}.ffn_down_exps.{expert_idx}.numa.{numa_idx}.weight"
|
||||
s_key = f"blk.{layer_idx}.ffn_down_exps.{expert_idx}.numa.{numa_idx}.scale"
|
||||
|
||||
if w_key not in index:
|
||||
raise FileNotFoundError(f"Weight key not found: {w_key}")
|
||||
|
||||
with safe_open(index[w_key], framework="pt") as f:
|
||||
down_w_parts.append(f.get_tensor(w_key))
|
||||
down_s_parts.append(f.get_tensor(s_key))
|
||||
|
||||
down_w = torch.cat(down_w_parts, dim=0)
|
||||
down_s = torch.cat(down_s_parts, dim=0)
|
||||
down_w = down_w.view(hidden_size, intermediate_size)
|
||||
|
||||
down_weights_list.append(down_w)
|
||||
down_scales_list.append(down_s)
|
||||
|
||||
gate_proj = torch.stack(gate_weights_list, dim=0)
|
||||
gate_scale = torch.stack(gate_scales_list, dim=0)
|
||||
up_proj = torch.stack(up_weights_list, dim=0)
|
||||
up_scale = torch.stack(up_scales_list, dim=0)
|
||||
down_proj = torch.stack(down_weights_list, dim=0)
|
||||
down_scale = torch.stack(down_scales_list, dim=0)
|
||||
|
||||
return INT8ExpertWeights(
|
||||
gate_proj=gate_proj,
|
||||
gate_scale=gate_scale,
|
||||
up_proj=up_proj,
|
||||
up_scale=up_scale,
|
||||
down_proj=down_proj,
|
||||
down_scale=down_scale,
|
||||
)
|
||||
588
kt-kernel/python/sft/wrapper.py
Normal file
588
kt-kernel/python/sft/wrapper.py
Normal file
|
|
@ -0,0 +1,588 @@
|
|||
# Model wrapping entry points for SFT
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import importlib.util as _u
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .arch import (
|
||||
KTAMXConfigError,
|
||||
KTAMXNotAvailableError,
|
||||
MOEArchConfig,
|
||||
_get_layers_prefix,
|
||||
_get_model_container_and_layers,
|
||||
get_moe_arch_config,
|
||||
get_moe_module,
|
||||
)
|
||||
from .layer import KTMoELayerWrapper
|
||||
from .lora import LoRAExperts
|
||||
from .weights import (
|
||||
_clear_original_expert_weights,
|
||||
extract_moe_weights,
|
||||
load_experts_from_checkpoint_files,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KT_KERNEL_AVAILABLE = _u.find_spec("kt_kernel") is not None
|
||||
|
||||
if KT_KERNEL_AVAILABLE:
|
||||
try:
|
||||
from kt_kernel.experts import KTMoEWrapper
|
||||
except Exception:
|
||||
KTMoEWrapper = None
|
||||
KT_KERNEL_AVAILABLE = False
|
||||
else:
|
||||
KTMoEWrapper = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Device-map builders
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _get_kt_config(kt_plugin: Any):
|
||||
"""Extract KTConfig from a KTransformersPlugin or compatible object.
|
||||
|
||||
KTConfig field names use kt_ prefix, matching the dict keys in
|
||||
HfTrainerKTConfig exactly — no name-mapping needed.
|
||||
"""
|
||||
from .config import KTConfig
|
||||
|
||||
if isinstance(kt_plugin, KTConfig):
|
||||
return kt_plugin
|
||||
|
||||
kt_config = getattr(kt_plugin, "kt_config", None)
|
||||
if kt_config is not None and isinstance(kt_config, KTConfig):
|
||||
return kt_config
|
||||
|
||||
return KTConfig.from_object(kt_plugin)
|
||||
|
||||
|
||||
def build_kt_device_map(config, kt_plugin, device: str = "cuda:0") -> dict[str, str | int]:
|
||||
"""
|
||||
Build device_map for KT model loading with hybrid GPU/CPU expert placement.
|
||||
"""
|
||||
moe_config = get_moe_arch_config(config)
|
||||
layers_prefix = _get_layers_prefix(config)
|
||||
num_layers = config.num_hidden_layers
|
||||
num_experts = moe_config.expert_num
|
||||
cfg = _get_kt_config(kt_plugin)
|
||||
num_gpu_experts = getattr(cfg, "kt_num_gpu_experts", 0) or 0
|
||||
|
||||
device_map: dict[str, str | int] = {}
|
||||
|
||||
device_map["model.embed_tokens"] = device
|
||||
device_map["model.norm"] = device
|
||||
device_map["lm_head"] = device
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
layer_prefix = f"{layers_prefix}.{layer_idx}"
|
||||
device_map[layer_prefix] = device
|
||||
moe_prefix = f"{layer_prefix}.{moe_config.moe_layer_attr}"
|
||||
|
||||
for expert_idx in range(num_experts):
|
||||
expert_key = f"{moe_prefix}.{moe_config.experts_attr}.{expert_idx}"
|
||||
if expert_idx < num_gpu_experts:
|
||||
device_map[expert_key] = device
|
||||
else:
|
||||
device_map[expert_key] = "cpu"
|
||||
|
||||
logger.info(
|
||||
f"Built KT device_map: {num_gpu_experts} GPU experts, {num_experts - num_gpu_experts} CPU experts"
|
||||
)
|
||||
|
||||
return device_map
|
||||
|
||||
|
||||
def build_kt_device_map_simplified(config, kt_plugin, device: str = "cuda:0") -> dict[str, str | int]:
|
||||
"""
|
||||
Simplified device_map builder: map full layers to GPU, override routed experts to CPU.
|
||||
"""
|
||||
moe_config = get_moe_arch_config(config)
|
||||
layers_prefix = _get_layers_prefix(config)
|
||||
num_layers = config.num_hidden_layers
|
||||
cfg = _get_kt_config(kt_plugin)
|
||||
num_gpu_experts = getattr(cfg, "kt_num_gpu_experts", 0) or 0
|
||||
|
||||
device_map: dict[str, str | int] = {}
|
||||
|
||||
device_map["model.embed_tokens"] = device
|
||||
device_map["model.norm"] = device
|
||||
device_map["lm_head"] = device
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
layer_prefix = f"{layers_prefix}.{layer_idx}"
|
||||
device_map[layer_prefix] = device
|
||||
|
||||
experts_prefix = f"{layer_prefix}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}"
|
||||
|
||||
if num_gpu_experts == 0:
|
||||
device_map[experts_prefix] = "cpu"
|
||||
else:
|
||||
return build_kt_device_map(config, kt_plugin, device=device)
|
||||
|
||||
logger.info("Built simplified KT device_map: all layers on GPU, routed experts on CPU")
|
||||
return device_map
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MoE layer wrapping
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KTMoELayerWrapper]:
|
||||
"""
|
||||
Replace model's MoE layers with KTMoEWrapper-based wrappers.
|
||||
|
||||
Loads expert weights into the C++ KT kernel. No LoRA initialization ---
|
||||
LoRA is handled by PEFT and later adapted via kt_adapt_peft_lora().
|
||||
Only rank 0 initializes KT kernel and loads weights.
|
||||
"""
|
||||
import torch.distributed as dist
|
||||
|
||||
if not KT_KERNEL_AVAILABLE:
|
||||
raise KTAMXNotAvailableError("kt_kernel not found. Please install kt_kernel to enable KT MoE support.")
|
||||
|
||||
# Only rank 0 should initialize KT and load weights
|
||||
is_rank_0 = True
|
||||
if dist.is_initialized():
|
||||
is_rank_0 = dist.get_rank() == 0
|
||||
|
||||
moe_config = get_moe_arch_config(model.config)
|
||||
_text_cfg = getattr(model.config, "text_config", model.config)
|
||||
hidden_size = _text_cfg.hidden_size
|
||||
|
||||
cfg = _get_kt_config(kt_plugin)
|
||||
|
||||
# Read lora_rank/lora_alpha for C++ wrapper initialization (buffer allocation only)
|
||||
lora_rank = getattr(cfg, "kt_lora_rank", 1) or 1
|
||||
lora_alpha = getattr(cfg, "kt_lora_alpha", 1.0) or 1.0
|
||||
|
||||
# Read LoRA Experts configuration
|
||||
_raw_le = getattr(cfg, "kt_use_lora_experts", None)
|
||||
use_lora_experts = bool(_raw_le) if _raw_le is not None else False
|
||||
lora_expert_num = getattr(cfg, "kt_lora_expert_num", 2) or 2
|
||||
lora_expert_intermediate_size = getattr(cfg, "kt_lora_expert_intermediate_size", 1024) or 1024
|
||||
|
||||
if is_rank_0:
|
||||
logger.info(
|
||||
f"LoRA Experts config: use_lora_experts={use_lora_experts}, "
|
||||
f"num={lora_expert_num}, intermediate_size={lora_expert_intermediate_size}"
|
||||
)
|
||||
|
||||
wrappers: list[KTMoELayerWrapper] = []
|
||||
moe_layer_count = 0
|
||||
|
||||
kt_backend_map = {
|
||||
"AMXBF16": "AMXBF16_SFT",
|
||||
"AMXINT8": "AMXINT8_SFT",
|
||||
"AMXINT4": "AMXINT4_SFT",
|
||||
"AMXBF16_SkipLoRA": "AMXBF16_SFT_SkipLoRA",
|
||||
"AMXINT8_SkipLoRA": "AMXINT8_SFT_SkipLoRA",
|
||||
"AMXINT4_SkipLoRA": "AMXINT4_SFT_SkipLoRA",
|
||||
}
|
||||
# Build case-insensitive lookup to handle common typos like "SkipLora" vs "SkipLoRA"
|
||||
_kt_backend_map_lower = {k.lower(): v for k, v in kt_backend_map.items()}
|
||||
kt_backend = getattr(cfg, "kt_backend", "AMXBF16")
|
||||
kt_method = kt_backend_map.get(kt_backend) or _kt_backend_map_lower.get(kt_backend.lower(), "AMXBF16_SFT")
|
||||
if kt_method != kt_backend_map.get(kt_backend):
|
||||
logger.warning(
|
||||
f"kt_backend '{kt_backend}' matched via case-insensitive lookup -> '{kt_method}'. "
|
||||
f"Please use the exact name from: {list(kt_backend_map.keys())}"
|
||||
)
|
||||
|
||||
if "SkipLoRA" in kt_method:
|
||||
logger.info(f"Using SkipLoRA backend: {kt_method} (MoE LoRA gradients will be skipped)")
|
||||
|
||||
threadpool_count = getattr(cfg, "kt_threadpool_count", 1) if getattr(cfg, "kt_tp_enabled", False) else 1
|
||||
|
||||
kt_weight_path = getattr(cfg, "kt_weight_path", None)
|
||||
use_kt_weight_path = kt_weight_path is not None
|
||||
if use_kt_weight_path:
|
||||
logger.info(f"Loading INT8 weights from kt_weight_path: {kt_weight_path}")
|
||||
|
||||
checkpoint_files = getattr(cfg, "kt_checkpoint_files", None)
|
||||
sharded_metadata = getattr(cfg, "kt_sharded_metadata", None)
|
||||
|
||||
# When kt_expert_checkpoint_path is set, always resolve from it (overrides any existing
|
||||
# checkpoint_files which may come from AttnOnlyBf16 and lack expert weights).
|
||||
kt_expert_checkpoint_path = getattr(cfg, "kt_expert_checkpoint_path", None)
|
||||
if kt_expert_checkpoint_path:
|
||||
logger.info(f"Resolving expert checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}")
|
||||
resolved_files, resolved_meta = _resolve_checkpoint_files(model_name_or_path=kt_expert_checkpoint_path)
|
||||
if resolved_files and all(f.endswith(".safetensors") for f in resolved_files):
|
||||
checkpoint_files = resolved_files
|
||||
sharded_metadata = resolved_meta
|
||||
cfg.kt_checkpoint_files = checkpoint_files
|
||||
cfg.kt_sharded_metadata = sharded_metadata
|
||||
logger.info(f"Resolved {len(checkpoint_files)} checkpoint files from kt_expert_checkpoint_path")
|
||||
else:
|
||||
logger.warning(f"Failed to resolve checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}")
|
||||
|
||||
use_checkpoint_files = bool(checkpoint_files) and not use_kt_weight_path
|
||||
|
||||
logger.debug(
|
||||
f"Weight source: kt_weight_path={kt_weight_path!r}, "
|
||||
f"kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}, "
|
||||
f"checkpoint_files count={len(checkpoint_files) if checkpoint_files else 0}, "
|
||||
f"use_kt_weight_path={use_kt_weight_path}, use_checkpoint_files={use_checkpoint_files}"
|
||||
)
|
||||
|
||||
if use_checkpoint_files:
|
||||
logger.info("Loading expert weights from checkpoint files (online conversion).")
|
||||
elif use_kt_weight_path and bool(checkpoint_files):
|
||||
logger.info("BF16 checkpoint files available for backward gradient computation.")
|
||||
elif (not use_kt_weight_path) and bool(getattr(cfg, "kt_skip_expert_loading", False)):
|
||||
# If HF expert weights were skipped during `from_pretrained`, we must source expert weights externally.
|
||||
model_name_or_path = getattr(getattr(model, "config", None), "name_or_path", None)
|
||||
if model_name_or_path:
|
||||
resolved_files, resolved_meta = _resolve_checkpoint_files(model_name_or_path=model_name_or_path)
|
||||
if resolved_files and all(f.endswith(".safetensors") for f in resolved_files):
|
||||
checkpoint_files = resolved_files
|
||||
sharded_metadata = resolved_meta
|
||||
cfg.kt_checkpoint_files = checkpoint_files
|
||||
cfg.kt_sharded_metadata = sharded_metadata
|
||||
use_checkpoint_files = True
|
||||
logger.info("KT skip_expert_loading enabled; using checkpoint files for online expert loading.")
|
||||
|
||||
if not use_checkpoint_files:
|
||||
raise KTAMXConfigError(
|
||||
"KT skip_expert_loading is enabled but no `kt_weight_path` was provided and no safetensors checkpoint "
|
||||
"files could be resolved for on-the-fly expert loading."
|
||||
)
|
||||
|
||||
import torch.distributed as _dist
|
||||
_rank = _dist.get_rank() if _dist.is_initialized() else 0
|
||||
|
||||
model_container, layers = _get_model_container_and_layers(model, purpose="wrapping")
|
||||
logger.info(f"Total layers={len(layers)}, is_rank_0={is_rank_0}")
|
||||
|
||||
from .arch import detect_fused_experts as _detect_fused
|
||||
|
||||
for layer_idx, layer in enumerate(layers):
|
||||
moe_module = get_moe_module(layer, moe_config)
|
||||
if moe_module is None:
|
||||
continue
|
||||
|
||||
_layer_experts = getattr(moe_module, moe_config.experts_attr, None)
|
||||
_layer_is_fused = _detect_fused(_layer_experts)
|
||||
|
||||
logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method}, fused={_layer_is_fused})")
|
||||
|
||||
# Only rank 0 loads weights and initializes KT kernel
|
||||
gate_proj, up_proj, down_proj = None, None, None
|
||||
wrapper = None
|
||||
|
||||
if is_rank_0:
|
||||
# Get block_size from quantization_config if available (for FP8 dequant)
|
||||
_quant_cfg = getattr(model.config, "quantization_config", None)
|
||||
_block_size = None
|
||||
if _quant_cfg is not None:
|
||||
_block_size = getattr(_quant_cfg, "weight_block_size", None)
|
||||
|
||||
if use_kt_weight_path:
|
||||
logger.debug(f"Layer {layer_idx}: forward + backward from kt_weight_path (.kt files)")
|
||||
elif use_checkpoint_files:
|
||||
layers_prefix = _get_layers_prefix(model.config)
|
||||
gate_proj, up_proj, down_proj = load_experts_from_checkpoint_files(
|
||||
checkpoint_files=checkpoint_files,
|
||||
sharded_metadata=sharded_metadata,
|
||||
layers_prefix=layers_prefix,
|
||||
moe_config=moe_config,
|
||||
layer_idx=layer_idx,
|
||||
block_size=_block_size,
|
||||
)
|
||||
else:
|
||||
gate_proj, up_proj, down_proj = extract_moe_weights(moe_module, moe_config)
|
||||
gate_proj = gate_proj.cpu().to(torch.bfloat16).contiguous()
|
||||
up_proj = up_proj.cpu().to(torch.bfloat16).contiguous()
|
||||
down_proj = down_proj.cpu().to(torch.bfloat16).contiguous()
|
||||
|
||||
chunked_prefill_size = getattr(cfg, "kt_model_max_length", None)
|
||||
if chunked_prefill_size is None:
|
||||
chunked_prefill_size = getattr(model.config, "max_position_embeddings", 4096)
|
||||
|
||||
# Only rank 0 creates KTMoEWrapper and loads weights
|
||||
if is_rank_0:
|
||||
wrapper = KTMoEWrapper(
|
||||
layer_idx=layer_idx,
|
||||
num_experts=moe_config.expert_num,
|
||||
num_experts_per_tok=moe_config.num_experts_per_tok,
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_size=moe_config.intermediate_size,
|
||||
gpu_experts_mask=None,
|
||||
num_gpu_experts=0,
|
||||
cpuinfer_threads=getattr(cfg, "kt_num_threads", 1),
|
||||
threadpool_count=threadpool_count,
|
||||
weight_path=kt_weight_path or "",
|
||||
chunked_prefill_size=chunked_prefill_size,
|
||||
method=kt_method,
|
||||
mode="sft",
|
||||
lora_rank=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
max_cache_depth=getattr(cfg, "kt_max_cache_depth", 2),
|
||||
)
|
||||
|
||||
# Set share_backward_bb and share_cache_pool BEFORE load_weights (config is built during load)
|
||||
wrapper.share_backward_bb = cfg.kt_share_backward_bb
|
||||
wrapper.share_cache_pool = cfg.kt_share_cache_pool
|
||||
|
||||
physical_to_logical_map = torch.arange(moe_config.expert_num, dtype=torch.int64, device="cpu")
|
||||
|
||||
if use_kt_weight_path:
|
||||
logger.debug(f"Layer {layer_idx}: calling wrapper.load_weights() (C++ direct .kt load)")
|
||||
wrapper.load_weights(physical_to_logical_map)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Layer {layer_idx}: calling wrapper.load_weights_from_tensors() "
|
||||
f"(BF16 tensor path, gate_proj shape={gate_proj.shape if gate_proj is not None else None})"
|
||||
)
|
||||
wrapper.load_weights_from_tensors(
|
||||
gate_proj=gate_proj,
|
||||
up_proj=up_proj,
|
||||
down_proj=down_proj,
|
||||
physical_to_logical_map_cpu=physical_to_logical_map,
|
||||
)
|
||||
|
||||
wrapper.gate_proj = None
|
||||
wrapper.up_proj = None
|
||||
wrapper.down_proj = None
|
||||
|
||||
# Create LoRA Experts if enabled
|
||||
lora_experts = None
|
||||
if use_lora_experts:
|
||||
lora_experts = LoRAExperts(
|
||||
num_experts=lora_expert_num,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=lora_expert_intermediate_size,
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
layer_wrapper = KTMoELayerWrapper(
|
||||
original_moe=moe_module,
|
||||
wrapper=wrapper,
|
||||
lora_params=None,
|
||||
moe_config=moe_config,
|
||||
hidden_size=hidden_size,
|
||||
layer_idx=layer_idx,
|
||||
lora_experts=lora_experts,
|
||||
)
|
||||
layer_wrapper._fused_experts = _layer_is_fused
|
||||
layer_wrapper._lora_rank = lora_rank
|
||||
|
||||
setattr(layer, moe_config.moe_layer_attr, layer_wrapper)
|
||||
# Base weights have been copied into the C++ kernel's internal BufferB format.
|
||||
# Do not hold a Python-side reference --- it wastes ~1 GB/layer.
|
||||
del gate_proj, up_proj, down_proj
|
||||
|
||||
wrappers.append(layer_wrapper)
|
||||
moe_layer_count += 1
|
||||
|
||||
# Replace original expert weights with meta placeholders.
|
||||
# Experts remain in the model tree (via wrapper.experts) so PEFT can discover them.
|
||||
# Rank 0 already copied weights to C++ kernel via load_weights_from_tensors.
|
||||
_clear_original_expert_weights(moe_module, moe_config)
|
||||
|
||||
logger.info(f"Wrapped {moe_layer_count} MoE layers with KTMoEWrapper")
|
||||
|
||||
# Link wrappers for async backward repack (higher layer triggers repack for lower)
|
||||
for i in range(1, len(wrappers)):
|
||||
if wrappers[i].wrapper is not None and wrappers[i - 1].wrapper is not None:
|
||||
wrappers[i].wrapper._next_backward_wrapper = wrappers[i - 1].wrapper
|
||||
if wrappers and wrappers[0].wrapper is not None:
|
||||
wrappers[0].wrapper._next_backward_wrapper = None
|
||||
|
||||
gc.collect()
|
||||
return wrappers
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Plugin builder
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _build_kt_plugin_from_args(model_args: Any, finetuning_args: Any | None = None):
|
||||
"""
|
||||
Build a KTransformersPlugin from model_args and optional finetuning_args.
|
||||
|
||||
Imported here to avoid circular dependency --- callers that need the plugin
|
||||
class should import it from the appropriate dataclasses module.
|
||||
"""
|
||||
from .config import KTConfig
|
||||
from accelerate.utils.dataclasses import KTransformersPlugin
|
||||
|
||||
kt_config = KTConfig(
|
||||
kt_backend=getattr(model_args, "kt_backend", None),
|
||||
kt_num_threads=getattr(model_args, "kt_num_threads", None),
|
||||
kt_tp_enabled=getattr(model_args, "kt_tp_enabled", None),
|
||||
kt_threadpool_count=getattr(model_args, "kt_threadpool_count", None),
|
||||
kt_max_cache_depth=getattr(model_args, "kt_max_cache_depth", None),
|
||||
kt_num_gpu_experts=getattr(model_args, "kt_num_gpu_experts", None),
|
||||
kt_weight_path=getattr(model_args, "kt_weight_path", None),
|
||||
kt_expert_checkpoint_path=getattr(model_args, "kt_expert_checkpoint_path", None),
|
||||
kt_use_lora_experts=getattr(model_args, "kt_use_lora_experts", None),
|
||||
kt_lora_expert_num=getattr(model_args, "kt_lora_expert_num", None),
|
||||
kt_lora_expert_intermediate_size=getattr(model_args, "kt_lora_expert_intermediate_size", None),
|
||||
kt_lora_rank=getattr(finetuning_args, "lora_rank", None) if finetuning_args else None,
|
||||
kt_lora_alpha=getattr(finetuning_args, "lora_alpha", None) if finetuning_args else None,
|
||||
kt_model_max_length=getattr(model_args, "model_max_length", None),
|
||||
)
|
||||
return KTransformersPlugin(enabled=True, kt_config=kt_config)
|
||||
|
||||
|
||||
def get_kt_loading_kwargs(
|
||||
config,
|
||||
kt_plugin,
|
||||
torch_dtype: torch.dtype | str | None = torch.bfloat16,
|
||||
trust_remote_code: bool | None = None,
|
||||
token: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Get kwargs for AutoModel.from_pretrained() for KT loading."""
|
||||
kwargs: dict[str, Any] = {
|
||||
"config": config,
|
||||
"torch_dtype": torch_dtype,
|
||||
"device_map": "cpu",
|
||||
"low_cpu_mem_usage": True,
|
||||
}
|
||||
if trust_remote_code is not None:
|
||||
kwargs["trust_remote_code"] = trust_remote_code
|
||||
if token is not None:
|
||||
kwargs["token"] = token
|
||||
return kwargs
|
||||
|
||||
|
||||
def _resolve_checkpoint_files(
|
||||
model_name_or_path: str,
|
||||
cache_dir: str | None = None,
|
||||
revision: str | None = None,
|
||||
token: str | None = None,
|
||||
trust_remote_code: bool | None = None,
|
||||
) -> tuple[list[str] | None, dict | None]:
|
||||
"""Resolve HF checkpoint files. Depends on transformers internals."""
|
||||
try:
|
||||
from transformers.modeling_utils import _get_resolved_checkpoint_files
|
||||
except Exception:
|
||||
return None, None
|
||||
try:
|
||||
checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
|
||||
pretrained_model_name_or_path=model_name_or_path,
|
||||
subfolder="",
|
||||
variant=None,
|
||||
gguf_file=None,
|
||||
from_tf=False,
|
||||
from_flax=False,
|
||||
use_safetensors=None,
|
||||
cache_dir=cache_dir,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
local_files_only=False,
|
||||
token=token,
|
||||
user_agent={"file_type": "model", "framework": "pytorch"},
|
||||
revision=revision or "main",
|
||||
commit_hash=None,
|
||||
is_remote_code=bool(trust_remote_code),
|
||||
transformers_explicit_filename=None,
|
||||
)
|
||||
except Exception:
|
||||
return None, None
|
||||
return checkpoint_files, sharded_metadata
|
||||
|
||||
|
||||
def load_kt_model(
|
||||
config,
|
||||
model_args: Any | None = None,
|
||||
finetuning_args: Any | None = None,
|
||||
kt_plugin=None,
|
||||
model_name_or_path: str | None = None,
|
||||
trust_remote_code: bool | None = None,
|
||||
token: str | None = None,
|
||||
torch_dtype: torch.dtype | str | None = torch.bfloat16,
|
||||
**kwargs,
|
||||
) -> nn.Module:
|
||||
"""Load model with KTMoEWrapper backend."""
|
||||
from .arch import get_moe_arch_config, move_non_experts_to_gpu, get_expert_device, KTAMXNotAvailableError, KTAMXConfigError
|
||||
|
||||
if kt_plugin is None:
|
||||
if model_args is None:
|
||||
raise KTAMXConfigError("Either kt_plugin or model_args must be provided to load_kt_model().")
|
||||
kt_plugin = _build_kt_plugin_from_args(model_args, finetuning_args)
|
||||
|
||||
if model_name_or_path is None and model_args is not None:
|
||||
model_name_or_path = getattr(model_args, "model_name_or_path", None)
|
||||
if model_name_or_path is None:
|
||||
raise KTAMXConfigError("model_name_or_path is required to load_kt_model().")
|
||||
|
||||
if trust_remote_code is None and model_args is not None:
|
||||
trust_remote_code = getattr(model_args, "trust_remote_code", None)
|
||||
if token is None and model_args is not None:
|
||||
token = getattr(model_args, "hf_hub_token", None)
|
||||
cache_dir = getattr(model_args, "cache_dir", None) if model_args is not None else None
|
||||
revision = getattr(model_args, "revision", None) if model_args is not None else None
|
||||
|
||||
_ = get_moe_arch_config(config)
|
||||
|
||||
logger.info("Loading model with KTMoEWrapper backend")
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.integrations.kt import set_kt_config, unset_kt_config
|
||||
|
||||
loading_kwargs = get_kt_loading_kwargs(
|
||||
config, kt_plugin, torch_dtype=torch_dtype,
|
||||
trust_remote_code=trust_remote_code, token=token,
|
||||
)
|
||||
if model_args is not None:
|
||||
for key in ("cache_dir", "revision"):
|
||||
value = getattr(model_args, key, None)
|
||||
if value is not None:
|
||||
loading_kwargs[key] = value
|
||||
loading_kwargs.update(kwargs)
|
||||
|
||||
cfg = _get_kt_config(kt_plugin)
|
||||
|
||||
if getattr(cfg, "kt_skip_expert_loading", None) is None:
|
||||
checkpoint_files, sharded_metadata = _resolve_checkpoint_files(
|
||||
model_name_or_path=model_name_or_path,
|
||||
cache_dir=cache_dir, revision=revision,
|
||||
token=token, trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if checkpoint_files and all(f.endswith(".safetensors") for f in checkpoint_files):
|
||||
if getattr(cfg, "kt_weight_path", None) is None:
|
||||
cfg.kt_skip_expert_loading = True
|
||||
else:
|
||||
cfg.kt_skip_expert_loading = False
|
||||
cfg.kt_checkpoint_files = checkpoint_files
|
||||
cfg.kt_sharded_metadata = sharded_metadata
|
||||
else:
|
||||
cfg.kt_skip_expert_loading = False
|
||||
|
||||
set_kt_config(kt_plugin)
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **loading_kwargs)
|
||||
finally:
|
||||
unset_kt_config()
|
||||
|
||||
moe_config = get_moe_arch_config(config)
|
||||
move_non_experts_to_gpu(model, moe_config, device="cuda:0")
|
||||
|
||||
existing_wrappers = getattr(model, "_kt_wrappers", None)
|
||||
if existing_wrappers:
|
||||
logger.info(f"MoE layers already wrapped ({len(existing_wrappers)} layers), skipping re-wrap")
|
||||
wrappers = existing_wrappers
|
||||
else:
|
||||
wrappers = wrap_moe_layers_with_kt_wrapper(model, kt_plugin)
|
||||
|
||||
model._kt_wrappers = wrappers
|
||||
model._kt_tp_enabled = bool(getattr(cfg, "kt_tp_enabled", False))
|
||||
model._kt_use_lora_experts = bool(getattr(cfg, "kt_use_lora_experts", False))
|
||||
|
||||
logger.info("Model loaded with KTMoEWrapper backend successfully")
|
||||
return model
|
||||
Loading…
Add table
Add a link
Reference in a new issue