kvcache-ai-ktransformers/kt-kernel/python/sft/dist_utils.py
mrhaoxx 9544a8960d
feat(sft): AMX MoE SFT backend with LoRA support (#1936)
* feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework
  interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.

* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use kt_ prefix matching dict keys — eliminates
  _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline
(1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False,
  auto-set to True by trainer_config_process when gradient checkpointing
  is enabled

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument,
but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently
and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config,
  _get_layers_prefix returns model.language_model.layers for Qwen3.5,
  _get_model_container_and_layers searches language_model attr
- weights.py: load_experts_from_checkpoint_files detects fused format
  (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [fix](sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method
calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original
method names. This caused AttributeError on Qwen3.5-35B and other models.

* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace,
  sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp,
  moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
2026-04-22 11:27:01 +08:00


# Distributed and checkpoint utilities for SFT
# SPDX-License-Identifier: Apache-2.0
"""
Shared distributed communication and gradient-checkpoint detection helpers.

This is a leaf module: it imports nothing from other sft/ submodules.
"""

from __future__ import annotations

from contextlib import nullcontext
from typing import Any

import torch
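

def _all_gather_qlens(local_qlen: int, device: torch.device, world_size: int) -> list[int]:
    """Return every rank's local sequence length (qlen), in rank order.

    Collective: all ranks must call it. ``device`` must be usable by the process
    group's backend (e.g. a CUDA device when the backend is NCCL).
    """
    import torch.distributed as dist

    local_qlen_t = torch.tensor([int(local_qlen)], device=device, dtype=torch.int64)
    gathered = [torch.empty(1, device=device, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(gathered, local_qlen_t)
    return [int(t.item()) for t in gathered]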


def _qlen_offsets(all_qlens: list[int]) -> list[int]:
    """Exclusive prefix sums: rank r owns rows [offsets[r], offsets[r + 1])."""
    offsets = [0]
    for q in all_qlens:
        offsets.append(offsets[-1] + int(q))
    return offsets
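

def _dist_gather_varlen_to_rank0(
    local_tensor: torch.Tensor,
    *,
    all_qlens: list[int],
    rank: int,
    world_size: int,
) -> list[torch.Tensor] | None:
    """Gather per-rank tensors of varying leading length onto rank 0 with P2P ops.

    On rank 0, returns one tensor per rank in rank order, each shaped
    ``(all_qlens[r], *local_tensor.shape[1:])``. On every other rank, sends
    ``local_tensor`` to rank 0 (unless it is empty) and returns None. Zero-length
    chunks are allocated on rank 0 but never transferred. All ranks must pass the
    same ``all_qlens`` (e.g. from ``_all_gather_qlens``) and the same trailing dims.
    """
    import torch.distributed as dist

    local_tensor = local_tensor.contiguous()
    local_expected = int(all_qlens[rank])
    if local_tensor.shape[0] != local_expected:
        raise RuntimeError(
            f"Local leading dim mismatch on rank {rank}: got {local_tensor.shape[0]}, expected {local_expected}"
        )

    if rank == 0:
        gathered: list[torch.Tensor | None] = [None] * world_size
        gathered[0] = local_tensor
        ops: list[dist.P2POp] = []
        for src in range(1, world_size):
            qlen_src = int(all_qlens[src])
            recv_shape = (qlen_src, *local_tensor.shape[1:])
            recv = torch.empty(recv_shape, device=local_tensor.device, dtype=local_tensor.dtype)
            gathered[src] = recv
            if qlen_src > 0:  # empty chunks need no transfer
                ops.append(dist.P2POp(dist.irecv, recv, src))
        if ops:
            # Post all receives as one batch, then block until every one completes.
            reqs = dist.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
        out: list[torch.Tensor] = []
        for idx, t in enumerate(gathered):
            if t is None:
                raise RuntimeError(f"Missing gathered tensor for rank {idx} on rank0.")
            out.append(t)
        return out

    # Other ranks: send the local chunk to rank 0 (skipped when it is empty).
    if local_expected > 0:
        reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, local_tensor, 0)])
        for req in reqs:
            req.wait()
    return None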


def _dist_scatter_varlen_from_rank0(
    *,
    rank0_chunks: list[torch.Tensor] | None,
    all_qlens: list[int],
    rank: int,
    world_size: int,
    feature_shape: tuple[int, ...],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reverse of ``_dist_gather_varlen_to_rank0``: rank 0 sends ``rank0_chunks[r]`` to rank r.

    Every rank returns a freshly allocated ``(all_qlens[rank], *feature_shape)``
    tensor. ``rank0_chunks`` is only read on rank 0, where it must hold exactly one
    chunk per rank. Destinations with zero rows are skipped.
    """
    import torch.distributed as dist

    local_qlen = int(all_qlens[rank])
    local_out = torch.empty((local_qlen, *feature_shape), device=device, dtype=dtype)

    if rank == 0:
        if rank0_chunks is None or len(rank0_chunks) != world_size:
            raise RuntimeError("rank0_chunks must contain one chunk per rank on rank0.")
        if int(rank0_chunks[0].shape[0]) != local_qlen:
            raise RuntimeError(
                f"Rank0 local chunk mismatch: got {rank0_chunks[0].shape[0]}, expected {local_qlen}"
            )
        # Rank 0's own chunk is copied locally; no communication needed.
        if local_qlen > 0:
            local_out.copy_(rank0_chunks[0])
        ops: list[dist.P2POp] = []
        for dst in range(1, world_size):
            qlen_dst = int(all_qlens[dst])
            if qlen_dst <= 0:
                continue
            chunk = rank0_chunks[dst].contiguous()
            if int(chunk.shape[0]) != qlen_dst:
                raise RuntimeError(
                    f"Rank {dst} chunk mismatch on rank0: got {chunk.shape[0]}, expected {qlen_dst}"
                )
            ops.append(dist.P2POp(dist.isend, chunk, dst))
        if ops:
            # Post all sends as one batch, then block until every one completes.
            reqs = dist.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
        return local_out

    # Other ranks: receive this rank's slice from rank 0 (skipped when it is empty).
    if local_qlen > 0:
        reqs = dist.batch_isend_irecv([dist.P2POp(dist.irecv, local_out, 0)])
        for req in reqs:
            req.wait()
    return local_out


def _checkpoint_hook_mode() -> str:
    """Infer the gradient-checkpoint phase from the top of the saved-tensors hook stack.

    Returns one of:
      - "first_forward": non-reentrant checkpoint's _checkpoint_hook
      - "recompute": non-reentrant checkpoint's _recomputation_hook
      - "none": no default saved_tensors_hooks on top
      - "other": unknown hook stack entry
      - "error": failed to query hook stack

    Relies on a private PyTorch API and on the qualified names of the pack hooks
    defined in torch.utils.checkpoint, so it may need updating across PyTorch versions.
    """
    try:
        top = torch._C._autograd._top_saved_tensors_default_hooks(False)
    except Exception:
        return "error"
    if top is None:
        return "none"
    try:
        pack_fn, _ = top
        mod = getattr(pack_fn, "__module__", "")
        qual = getattr(pack_fn, "__qualname__", getattr(pack_fn, "__name__", ""))
        tag = f"{mod}.{qual}"
    except Exception:
        return "other"
    # The pack hooks are closures created in each hook class's __init__.
    if "_recomputation_hook.__init__.<locals>.pack_hook" in tag:
        return "recompute"
    if "_checkpoint_hook.__init__.<locals>.pack_hook" in tag:
        return "first_forward"
    return "other"
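

def _maybe_zero3_gathered_parameters(params: list[torch.nn.Parameter]):
    """Return a context manager that temporarily gathers ZeRO-3-partitioned ``params``.

    Under DeepSpeed ZeRO-3 this is ``deepspeed.zero.GatheredParameters(params,
    modifier_rank=0)``, so changes made on rank 0 inside the block are broadcast to
    the other ranks on exit. In every other case (no params, transformers or
    deepspeed unavailable, ZeRO-3 disabled) it returns a no-op ``nullcontext()``.
    """
    if not params:
        return nullcontext()
    try:
        from transformers.integrations import is_deepspeed_zero3_enabled
    except Exception:
        return nullcontext()
    if not is_deepspeed_zero3_enabled():
        return nullcontext()
    try:
        import deepspeed  # type: ignore
    except Exception:
        return nullcontext()
    return deepspeed.zero.GatheredParameters(params, modifier_rank=0)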
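A minimal single-process sketch of the data flow the varlen gather/scatter helpers implement, assuming a caller uses them for the rank-0-only expert pattern described in the commit message: each rank's rows are gathered on rank 0, rank 0 runs the experts once on the concatenation, and each rank receives back exactly its own slice. Plain Python lists stand in for tensors and no torch.distributed calls are made. `qlen_offsets`, `simulate_round_trip`, and `expert_fn` are hypothetical names, and the concatenate-then-slice step is an assumption about the caller, not code from this module. With real tensors the same steps would map onto `torch.cat` and `torch.split(out, all_qlens)`.

```python
from __future__ import annotations

from typing import Callable


def qlen_offsets(all_qlens: list[int]) -> list[int]:
    """Exclusive prefix sums of per-rank lengths (same logic as _qlen_offsets)."""
    offsets = [0]
    for q in all_qlens:
        offsets.append(offsets[-1] + int(q))
    return offsets


def simulate_round_trip(
    per_rank_rows: list[list[float]],
    expert_fn: Callable[[list[float]], list[float]],
) -> list[list[float]]:
    """Hypothetical stand-in for gather-to-rank0 -> compute -> scatter-from-rank0.

    per_rank_rows[r] plays the role of rank r's local tensor (one float per row).
    """
    # What _all_gather_qlens would return on every rank.
    all_qlens = [len(rows) for rows in per_rank_rows]
    offsets = qlen_offsets(all_qlens)

    # Gather: rank 0 holds one chunk per rank, in rank order (empty chunks allowed).
    gathered = [list(rows) for rows in per_rank_rows]

    # Rank 0 concatenates the chunks and runs the experts once on the whole batch.
    batch = [row for chunk in gathered for row in chunk]
    out = expert_fn(batch)
    if len(out) != offsets[-1]:
        raise RuntimeError("expert_fn must return one output row per input row")

    # Scatter: rank r receives output rows [offsets[r], offsets[r + 1]).
    return [out[offsets[r]:offsets[r + 1]] for r in range(len(all_qlens))]


if __name__ == "__main__":
    print(simulate_round_trip([[1.0, 2.0], [], [3.0]], lambda xs: [2 * x for x in xs]))
    # [[2.0, 4.0], [], [6.0]]
```

An empty rank keeps its (empty) slot in rank order, mirroring how the real helpers allocate a zero-length tensor but skip the P2P op when a qlen is 0.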
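`_maybe_zero3_gathered_parameters` always returns something usable in a `with` statement, so callers need no ZeRO-3 branch of their own. The sketch below generalizes that fallback chain under stated assumptions: `maybe_backend_context`, `DemoGather`, and the use of the stdlib `math` module as a stand-in "backend" are all hypothetical, the enablement check is injected as a callable instead of imported from transformers, and only `ImportError` is caught (the original catches any `Exception`).

```python
from __future__ import annotations

import importlib
from contextlib import AbstractContextManager, nullcontext
from typing import Callable, Sequence


def maybe_backend_context(
    params: Sequence[object],
    is_enabled: Callable[[], bool],
    backend_module: str,
    make_context: Callable[[object, Sequence[object]], AbstractContextManager],
) -> AbstractContextManager:
    """Return the backend's context manager if it applies, else a no-op nullcontext().

    Simplified decision chain of _maybe_zero3_gathered_parameters:
    nothing to gather -> feature disabled -> backend not importable -> real context.
    """
    if not params:
        return nullcontext()
    if not is_enabled():
        return nullcontext()
    try:
        backend = importlib.import_module(backend_module)
    except ImportError:
        return nullcontext()
    return make_context(backend, params)


class DemoGather:
    """Hypothetical stand-in for deepspeed.zero.GatheredParameters; logs enter/exit."""

    def __init__(self, params: Sequence[object], log: list) -> None:
        self.params = params
        self.log = log

    def __enter__(self) -> DemoGather:
        self.log.append(("gather", len(self.params)))
        return self

    def __exit__(self, *exc: object) -> bool:
        self.log.append(("release",))
        return False  # never swallow exceptions


if __name__ == "__main__":
    log: list = []
    ctx = maybe_backend_context([1, 2], lambda: True, "math", lambda mod, ps: DemoGather(ps, log))
    with ctx:
        pass  # caller code is identical whether or not ctx is a no-op
    print(log)  # [('gather', 2), ('release',)]
```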