refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use the kt_ prefix, matching their dict keys, which
  eliminates the _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
  (see the sketches after this list)
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check
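
A minimal sketch of the new field layout from the first bullet, assuming a
from_dict-style constructor; kt_share_cache_pool and kt_share_backward_bb
follow names mentioned in this message, everything else (from_dict, the set
of fields) is illustrative rather than the repo's actual API:

    from dataclasses import dataclass, fields

    @dataclass
    class KTConfig:
        # Only two kt_-prefixed fields shown; the real config has more.
        kt_share_backward_bb: bool = False
        kt_share_cache_pool: bool = False

        @classmethod
        def from_dict(cls, cfg: dict) -> "KTConfig":
            # Field names now equal the dict keys, so construction is a plain
            # filter: no _OLD_TO_NEW rename table and no kt_ prefix stripping
            # in wrapper.py.
            known = {f.name for f in fields(cls)}
            return cls(**{k: v for k, v in cfg.items() if k in known})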
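
The env-var fallbacks dropped from wrapper.py are handled once in
KTConfig.__post_init__; a self-contained sketch that repeats the two fields
above, with hypothetical env variable names (the real names are not quoted
in this message):

    import os
    from dataclasses import dataclass

    def _env_flag(name: str, default: bool) -> bool:
        # Treat "1"/"true"/"yes" (case-insensitive) as True.
        val = os.environ.get(name)
        return default if val is None else val.strip().lower() in ("1", "true", "yes")

    @dataclass
    class KTConfig:
        kt_share_backward_bb: bool = False
        kt_share_cache_pool: bool = False

        def __post_init__(self):
            # Single place for env overrides, so wrapper.py no longer needs
            # its own fallback for share_backward_bb / share_cache_pool.
            self.kt_share_backward_bb = _env_flag("KT_SHARE_BACKWARD_BB", self.kt_share_backward_bb)
            self.kt_share_cache_pool = _env_flag("KT_SHARE_CACHE_POOL", self.kt_share_cache_pool)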
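
And the auto-enable rule from the second bullet, as a sketch; the helper
name and wiring are assumptions, only "gradient_checkpointing on implies
kt_share_cache_pool on" comes from this message:

    def apply_kt_defaults(training_args, kt_config: "KTConfig") -> "KTConfig":
        # training_args stands in for the HF TrainingArguments seen by
        # training_args.py; the flag later flows through to the C++ cache
        # allocation.
        if getattr(training_args, "gradient_checkpointing", False):
            kt_config.kt_share_cache_pool = True
        return kt_config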

Verified: Qwen3-235B 3-step training on sap4; per-step losses match the
baseline within noise (1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Author: mrhaoxx
Date:   2026-04-09 14:17:50 +08:00
Parent: a98d544833
Commit: 020eb929f7
5 changed files with 127 additions and 172 deletions


@@ -26,7 +26,6 @@ from .dist_utils import (
     _checkpoint_hook_mode,
     _dist_gather_varlen_to_rank0,
     _dist_scatter_varlen_from_rank0,
-    _is_in_checkpoint_first_forward,
     _qlen_offsets,
 )
@@ -123,23 +122,9 @@ class KTMoELayerWrapper(nn.Module):
             and torch.is_grad_enabled()
             and (hidden_states.requires_grad or topk_weights.requires_grad or train_lora)
         )
-        ckpt_hook_mode = _checkpoint_hook_mode()
-        in_ckpt_recompute = ckpt_hook_mode == "recompute"
-        in_ckpt_first_forward = ckpt_hook_mode == "first_forward"
-        if ckpt_hook_mode in ("none", "other", "error"):
-            # Fallback for environments where hook-top probing is unavailable.
-            in_ckpt_first_forward = _is_in_checkpoint_first_forward()
-        if in_ckpt_recompute:
-            # Recompute must be treated as non-first-forward in diagnostics.
-            in_ckpt_first_forward = False
-        # Keep KT autograd path whenever backward is needed. Disabling it in
-        # checkpoint first-forward prevents KTMoEFunction.backward from running.
         use_autograd_path = save_for_backward
         save_for_backward_submit = use_autograd_path
-        # Only suppress cache when we have high-confidence first_forward detection
-        # via the saved_tensors_hooks stack. The stack-walk fallback is too fragile
-        # for a correctness-critical decision — it only logs.
-        if ckpt_hook_mode == "first_forward":
+        if _checkpoint_hook_mode() == "first_forward":
             save_for_backward_submit = False

         if train_lora and self._lora_pointers_dirty: