refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- All KTConfig fields now use the kt_ prefix, matching their dict keys
  one-to-one; this eliminates the _OLD_TO_NEW mapping and the
  prefix-stripping logic in wrapper.py (sketched below)
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py); the flag flows through to the C++ cache
  allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to a single _checkpoint_hook_mode()
  check (see the sketches after this list)
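
A minimal sketch of the resulting config shape (kt_share_backward_bb, the
env var names, and the training-args hook are illustrative assumptions;
only the kt_ prefix rule and kt_share_cache_pool come from this commit):

import os
from dataclasses import dataclass

@dataclass
class KTConfig:
    # Fields carry the kt_ prefix, so they match the dict keys
    # one-to-one and wrapper.py needs no _OLD_TO_NEW renaming.
    kt_share_backward_bb: bool = False  # assumed sibling field
    kt_share_cache_pool: bool = False   # new in this commit

    def __post_init__(self):
        # Env overrides live here, once; wrapper.py keeps no fallbacks.
        # Env var names are assumptions.
        if os.environ.get("KT_SHARE_BACKWARD_BB") == "1":
            self.kt_share_backward_bb = True
        if os.environ.get("KT_SHARE_CACHE_POOL") == "1":
            self.kt_share_cache_pool = True

def apply_training_args(cfg: KTConfig, gradient_checkpointing: bool) -> None:
    # training_args.py hook (sketch): checkpointing recomputes forwards,
    # so sharing one cache pool across passes is safe and saves memory.
    if gradient_checkpointing:
        cfg.kt_share_cache_pool = True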
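
And a sketch of the single-check call site in layer.py (the mode string
and pool API are assumptions; _checkpoint_hook_mode() is the real helper):

def _get_cache(pool, layer_idx: int):
    # One query replaces the in_ckpt_recompute / in_ckpt_first_forward
    # flags and the stack-walking fallback.
    if _checkpoint_hook_mode() == "recompute":  # assumed mode string
        return pool.reuse(layer_idx)    # hypothetical pool API
    return pool.allocate(layer_idx)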

Verified: Qwen3-235B 3-step training on sap4; per-step losses closely
match baseline (1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mrhaoxx 2026-04-09 14:17:50 +08:00
parent a98d544833
commit 020eb929f7
5 changed files with 127 additions and 172 deletions


@@ -9,7 +9,6 @@ This is a leaf module — no imports from other sft/ submodules.
 from __future__ import annotations
-import inspect
 from contextlib import nullcontext
 from typing import Any
@@ -125,18 +124,6 @@ def _dist_scatter_varlen_from_rank0(
     return local_out
 
-def _is_in_checkpoint_first_forward() -> bool:
-    """Best-effort detection for non-reentrant checkpoint first forward."""
-    try:
-        for frame_info in inspect.stack(context=0):
-            fn = frame_info.function
-            file = frame_info.filename or ""
-            if fn == "custom_gradient_checkpointing_func" and file.endswith("checkpointing.py"):
-                return True
-    except Exception:
-        return False
-    return False
-
 def _checkpoint_hook_mode() -> str:
     """Infer checkpoint phase from current saved_tensors_hooks top.