Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD.

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MoE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds the C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load the sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B, 4-GPU AMX BF16 training; loss converges normally.
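Below is a minimal, illustrative sketch (not code from the repository) of the rank-0-only replacement pattern that layer.py implements; the repo's actual entry point is wrap_moe_layers_with_kt_wrapper in wrapper.py. The import path, the model attribute layout (model.model.layers, layer.mlp), and the make_cpp_wrapper / moe_config objects are assumptions for illustration only; the constructor keywords come from the KTMoELayerWrapper signature in the file below.

import torch.distributed as dist

from kt_kernel.sft.layer import KTMoELayerWrapper  # import path assumed


def replace_moe_layers(model, moe_config, make_cpp_wrapper):
    """Swap each HF MoE block for KTMoELayerWrapper (illustrative sketch only)."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    for layer_idx, layer in enumerate(model.model.layers):
        if not hasattr(layer.mlp, "experts"):
            continue  # dense FFN layer, leave untouched
        layer.mlp = KTMoELayerWrapper(
            original_moe=layer.mlp,
            # Rank-0-only expert pattern: only rank 0 gets the C++ wrapper.
            wrapper=make_cpp_wrapper(layer_idx) if rank == 0 else None,
            lora_params=None,  # kept for backward compatibility, ignored
            moe_config=moe_config,
            hidden_size=model.config.hidden_size,
            layer_idx=layer_idx,
        )
    return model
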
# KTMoELayerWrapper — nn.Module replacing HF MoE layers for SFT
# SPDX-License-Identifier: Apache-2.0

"""
KTMoELayerWrapper: drop-in nn.Module replacement for HuggingFace MoE layers.

Delegates expert computation to the C++ KTMoEWrapper backend, with support
for gradient checkpointing, PEFT LoRA on experts, LoRA Experts (separate
small MLPs on GPU), shared experts, and multi-GPU rank-0-only execution.
"""

from __future__ import annotations

import logging
import os
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F

from .arch import MOEArchConfig
from .autograd import KTMoEFunction
from .dist_utils import (
    _all_gather_qlens,
    _checkpoint_hook_mode,
    _dist_gather_varlen_to_rank0,
    _dist_scatter_varlen_from_rank0,
    _is_in_checkpoint_first_forward,
    _qlen_offsets,
)

logger = logging.getLogger(__name__)
_KT_SFT_DEBUG = os.environ.get("KT_SFT_DEBUG", "0") == "1"

class KTMoELayerWrapper(nn.Module):
    """Wrapper for MoE layer using KTMoEWrapper."""

    def __init__(
        self,
        original_moe: nn.Module,
        wrapper: Any,
        lora_params: dict[str, nn.Parameter] | None,  # Kept for backward compatibility, but ignored
        moe_config: MOEArchConfig,
        hidden_size: int,
        layer_idx: int,
        lora_experts: "LoRAExperts | None" = None,
    ):
        super().__init__()
        self._is_kt_moe_wrapper = True

        self.wrapper = wrapper
        self.moe_config = moe_config
        self.hidden_size = hidden_size
        self.layer_idx = layer_idx
        self.router_type = moe_config.router_type

        # IMPORTANT: Register submodules in the SAME ORDER as the original MoE module
        # so that PEFT's named_modules() traversal order matches the baseline.
        # This ensures kaiming_uniform_ calls happen in the same sequence.
        # Qwen3MoeSparseMoeBlock order: gate FIRST, then experts.

        # 1. gate/router FIRST - keep original attribute name for PEFT compatibility
        router_attr = moe_config.router_attr  # "gate" for Qwen3/DeepSeek
        setattr(self, router_attr, getattr(original_moe, router_attr, None))
        self._router_attr = router_attr

        # 2. experts SECOND (this is what PEFT targets for LoRA)
        experts_attr = moe_config.experts_attr  # typically "experts"
        setattr(self, experts_attr, getattr(original_moe, experts_attr, None))
        self._experts_attr = experts_attr

        # 3. shared_experts (if any)
        if moe_config.has_shared_experts and hasattr(original_moe, "shared_experts"):
            self.shared_experts = original_moe.shared_experts
        else:
            self.shared_experts = None

        # 4. lora_experts (separate LoRA expert MLPs, different from PEFT LoRA on experts)
        self.lora_experts = lora_experts

        # PEFT LoRA tracking (set by kt_adapt_peft_lora)
        # _peft_lora_modules: {expert_idx: {proj_name: (lora_A, lora_B)}}
        self._peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]] | None = None
        self._peft_lora_rank: int = 0
        self._peft_lora_alpha: float = 0.0
        self._skip_lora: bool = False  # True when using SkipLoRA backend (no LoRA on experts)

        self._lora_pointers_dirty = False

    def _apply(self, fn, recurse=True):
        # Protect experts from device transfer (PEFT LoRA should stay on CPU for KT)
        saved_experts = None
        experts_attr = getattr(self, '_experts_attr', None)

        if experts_attr is not None and getattr(self, experts_attr, None) is not None:
            saved_experts = getattr(self, experts_attr)
            self._modules.pop(experts_attr, None)

        result = super()._apply(fn, recurse)

        if saved_experts is not None:
            self._modules[experts_attr] = saved_experts

        return result

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        import torch.distributed as dist

        dist_on = dist.is_initialized() and dist.get_world_size() > 1
        rank = dist.get_rank() if dist.is_initialized() else 0

        # Check if we need to use distributed broadcast (only rank 0 has the KT kernel)
        use_broadcast = dist_on and self.wrapper is None

        topk_ids, topk_weights = self._compute_routing(hidden_states)

        train_lora = self._peft_lora_modules is not None and len(self._peft_lora_modules) > 0

        save_for_backward = (
            self.training
            and torch.is_grad_enabled()
            and (hidden_states.requires_grad or topk_weights.requires_grad or train_lora)
        )
        ckpt_hook_mode = _checkpoint_hook_mode()
        in_ckpt_recompute = ckpt_hook_mode == "recompute"
        in_ckpt_first_forward = ckpt_hook_mode == "first_forward"
        if ckpt_hook_mode in ("none", "other", "error"):
            # Fallback for environments where hook-top probing is unavailable.
            in_ckpt_first_forward = _is_in_checkpoint_first_forward()
        if in_ckpt_recompute:
            # Recompute must be treated as non-first-forward in diagnostics.
            in_ckpt_first_forward = False
        # Keep KT autograd path whenever backward is needed. Disabling it in
        # checkpoint first-forward prevents KTMoEFunction.backward from running.
        use_autograd_path = save_for_backward
        save_for_backward_submit = use_autograd_path
        # Only suppress the cache when we have high-confidence first_forward detection
        # via the saved_tensors_hooks stack. The stack-walk fallback is too fragile
        # for a correctness-critical decision — it only logs.
        if ckpt_hook_mode == "first_forward":
            save_for_backward_submit = False

        if train_lora and self._lora_pointers_dirty:
            self.update_lora_pointers()
            self._lora_pointers_dirty = False

        gpu_output, all_qlens = self._submit_and_compute_gpu(
            hidden_states,
            topk_ids,
            topk_weights,
            save_for_backward_submit,
        )
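        # After submit_forward the routed tokens are handled by the C++ backend while
        # shared/LoRA experts already ran on the GPU inside _submit_and_compute_gpu
        # (the "overlap path" naming below suggests this proceeds asynchronously);
        # the CPU expert output is synced later, either inside KTMoEFunction.apply
        # or in _sync_forward_output_no_autograd.
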
        # Use KTMoEFunction whenever backward is needed so KT backward and LoRA
        # gradient paths remain connected.
        if use_autograd_path:
            lora_ref = hidden_states.new_empty(())
            if train_lora and self._peft_lora_modules:
                for expert_loras in self._peft_lora_modules.values():
                    for lora_A, lora_B in expert_loras.values():
                        if hasattr(lora_A, 'weight') and lora_A.weight.requires_grad:
                            lora_ref = lora_A.weight
                            break
                    if lora_ref.numel() > 0:
                        break

            moe_output = KTMoEFunction.apply(
                hidden_states,
                topk_ids,
                topk_weights,
                self.wrapper,
                lora_ref,
                self.hidden_size,
                self.moe_config.num_experts_per_tok,
                self.layer_idx,
                save_for_backward,
                train_lora,
                all_qlens,
            )
        else:
            moe_output = self._sync_forward_output_no_autograd(
                hidden_states=hidden_states,
                all_qlens=all_qlens,
            )

        if gpu_output is not None:
            moe_output = moe_output + gpu_output

        return moe_output

    def _sync_forward_output_no_autograd(
        self,
        hidden_states: torch.Tensor,
        all_qlens: list[int] | tuple[int, ...] | None,
    ) -> torch.Tensor:
        """Sync CPU expert output without creating KTMoEFunction autograd nodes."""
        import torch.distributed as dist

        original_device = hidden_states.device
        original_dtype = hidden_states.dtype
        batch_size, seq_len, _ = hidden_states.shape
        qlen = batch_size * seq_len

        dist_on = dist.is_initialized() and dist.get_world_size() > 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        world_size = dist.get_world_size() if dist_on else 1

        if dist_on:
            if all_qlens is None:
                all_qlens_list = _all_gather_qlens(qlen, original_device, world_size)
            else:
                all_qlens_list = [int(q) for q in all_qlens]
                if len(all_qlens_list) != world_size:
                    raise RuntimeError(
                        f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}"
                    )
            if int(all_qlens_list[rank]) != qlen:
                raise RuntimeError(
                    f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}"
                )
            total_qlen = sum(all_qlens_list)

            if rank == 0:
                if self.wrapper is None:
                    raise RuntimeError("Rank0 wrapper is required in distributed KT overlap path.")
                cpu_output = self.wrapper.sync_forward(output_device=original_device)
                cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, self.hidden_size)
                offsets = _qlen_offsets(all_qlens_list)
                scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
            else:
                scatter_list = None

            output_flat = _dist_scatter_varlen_from_rank0(
                rank0_chunks=scatter_list,
                all_qlens=all_qlens_list,
                rank=rank,
                world_size=world_size,
                feature_shape=(self.hidden_size,),
                device=original_device,
                dtype=original_dtype,
            )
            output = output_flat.view(batch_size, seq_len, self.hidden_size)
            del output_flat
            return output

        if self.wrapper is not None:
            cpu_output = self.wrapper.sync_forward(output_device=original_device)
            output = cpu_output.view(batch_size, seq_len, self.hidden_size).to(dtype=original_dtype)
            return output

        return torch.empty(batch_size, seq_len, self.hidden_size, device=original_device, dtype=original_dtype)

    def _compute_routing(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Run routing under no_grad to avoid creating autograd nodes whose
        # SavedVariables become orphan holders inside gradient checkpoint.
        # The gate is frozen during LoRA fine-tuning and the main gradient
        # flows through KTMoEFunction.backward()'s grad_input, so the
        # routing gradient contribution to hidden_states can be safely dropped.
        with torch.no_grad():
            router = getattr(self, self._router_attr)
            if self.router_type == "deepseek_gate":
                # DeepSeek V3's MoEGate has `assert not self.training` in its noaux_tc
                # routing path because the HF model is an inference-only port.
                # For LoRA fine-tuning the router is frozen, so eval() is safe.
                was_training = router.training
                if was_training:
                    router.eval()
                router_output = router(hidden_states)
                if was_training:
                    router.train()
                if len(router_output) == 2:
                    topk_ids, topk_weights = router_output
                else:
                    topk_ids, topk_weights = router_output[0], router_output[1]
                if topk_weights.is_floating_point():
                    topk_weights = topk_weights.to(torch.bfloat16)
                return topk_ids, topk_weights

            router_logits = router(hidden_states.view(-1, self.hidden_size))
            routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
            topk_weights, topk_ids = torch.topk(routing_weights, self.moe_config.num_experts_per_tok, dim=-1)
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
            topk_weights = topk_weights.to(torch.bfloat16)
            return topk_ids, topk_weights

    def _submit_and_compute_gpu(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        save_for_backward: bool,
    ) -> tuple[torch.Tensor | None, list[int] | None]:
        import torch.distributed as dist

        batch_size, seq_len, _ = hidden_states.shape
        original_device = hidden_states.device
        original_dtype = hidden_states.dtype

        dist_on = dist.is_initialized() and dist.get_world_size() > 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        world_size = dist.get_world_size() if dist_on else 1

        qlen = batch_size * seq_len

        if dist_on:
            all_qlens = _all_gather_qlens(qlen, original_device, world_size)
            if int(all_qlens[rank]) != qlen:
                raise RuntimeError(
                    f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens[rank]}"
                )
            total_qlen = sum(all_qlens)

            hs_flat = hidden_states.view(qlen, self.hidden_size).contiguous()
            expert_ids = topk_ids.view(qlen, self.moe_config.num_experts_per_tok).contiguous()
            weights = topk_weights.view(qlen, self.moe_config.num_experts_per_tok).contiguous()

            submit_hs = hs_flat.detach()
            submit_ids = expert_ids.detach()
            submit_wts = weights.detach()

            gathered_hs = _dist_gather_varlen_to_rank0(
                submit_hs,
                all_qlens=all_qlens,
                rank=rank,
                world_size=world_size,
            )
            gathered_ids = _dist_gather_varlen_to_rank0(
                submit_ids,
                all_qlens=all_qlens,
                rank=rank,
                world_size=world_size,
            )
            gathered_wts = _dist_gather_varlen_to_rank0(
                submit_wts,
                all_qlens=all_qlens,
                rank=rank,
                world_size=world_size,
            )

            if rank == 0:
                all_hs = torch.cat(gathered_hs, dim=0)
                all_ids = torch.cat(gathered_ids, dim=0)
                all_wts = torch.cat(gathered_wts, dim=0)
                self.wrapper.submit_forward(
                    all_hs,
                    all_ids,
                    all_wts,
                    save_for_backward=save_for_backward,
                )

            # Keep shared/lora experts local to avoid qlen_max-style amplification.
            gpu_output = None
            if self.shared_experts is not None:
                gpu_output = self.shared_experts(hidden_states)
                gpu_output = gpu_output.to(dtype=original_dtype)

            if self.lora_experts is not None:
                lora_out = self.lora_experts(hidden_states)
                gpu_output = lora_out if gpu_output is None else gpu_output + lora_out

            return gpu_output, all_qlens

        else:
            # ---- Single-GPU path: submit + GPU compute ----
            input_flat = hidden_states.view(qlen, self.hidden_size)
            expert_ids = topk_ids.view(qlen, self.moe_config.num_experts_per_tok)
            weights = topk_weights.view(qlen, self.moe_config.num_experts_per_tok)

            # Avoid passing graph-attached tensors into the C++ cache.
            submit_hs = input_flat.detach()
            submit_ids = expert_ids.detach()
            submit_wts = weights.detach()
            self.wrapper.submit_forward(
                submit_hs,
                submit_ids,
                submit_wts,
                save_for_backward=save_for_backward,
            )

            # GPU compute: shared_experts + lora_experts
            gpu_output = None
            if self.shared_experts is not None:
                gpu_output = self.shared_experts(hidden_states)
            if self.lora_experts is not None:
                lora_out = self.lora_experts(hidden_states)
                gpu_output = lora_out if gpu_output is None else gpu_output + lora_out

            return gpu_output, None

    def update_lora_pointers(self):
        """Sync PEFT LoRA weights to the C++ kernel after an optimizer update."""
        # Skip if wrapper is None (non-rank-0 processes)
        if self.wrapper is None:
            return
        # Skip if wrapper is not properly initialized
        if not getattr(self.wrapper, "_weights_loaded", False):
            logger.warning(f"Layer {self.layer_idx}: Skipping update_lora_pointers - weights not loaded")
            return
        if not getattr(self.wrapper, "_lora_initialized", False):
            logger.warning(f"Layer {self.layer_idx}: Skipping update_lora_pointers - LoRA not initialized")
            return

        # PEFT weights are views into wrapper's contiguous buffers —
        # optimizer.step() already updated them in-place, just re-sync to C++.
        self.wrapper.update_lora_weights()