mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 11:49:51 +00:00
refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code
- KTConfig fields all use kt_ prefix matching dict keys — eliminates _OLD_TO_NEW mapping and prefix-stripping in wrapper.py - Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing is on (via training_args.py), flows through to C++ cache allocation - Remove dead checkpoint detection code: in_ckpt_recompute, in_ckpt_first_forward vars (assigned but never read), fallback _is_in_checkpoint_first_forward() function, unused inspect import - Remove redundant env var fallbacks in wrapper.py for share_backward_bb and share_cache_pool (KTConfig.__post_init__ already handles env vars) - Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check Verified: Qwen3-235B 3-step training on sap4, loss matches baseline (1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a98d544833
commit
020eb929f7
5 changed files with 127 additions and 172 deletions
|
|
@ -26,7 +26,6 @@ from .dist_utils import (
|
|||
_checkpoint_hook_mode,
|
||||
_dist_gather_varlen_to_rank0,
|
||||
_dist_scatter_varlen_from_rank0,
|
||||
_is_in_checkpoint_first_forward,
|
||||
_qlen_offsets,
|
||||
)
|
||||
|
||||
|
|
@ -123,23 +122,9 @@ class KTMoELayerWrapper(nn.Module):
|
|||
and torch.is_grad_enabled()
|
||||
and (hidden_states.requires_grad or topk_weights.requires_grad or train_lora)
|
||||
)
|
||||
ckpt_hook_mode = _checkpoint_hook_mode()
|
||||
in_ckpt_recompute = ckpt_hook_mode == "recompute"
|
||||
in_ckpt_first_forward = ckpt_hook_mode == "first_forward"
|
||||
if ckpt_hook_mode in ("none", "other", "error"):
|
||||
# Fallback for environments where hook-top probing is unavailable.
|
||||
in_ckpt_first_forward = _is_in_checkpoint_first_forward()
|
||||
if in_ckpt_recompute:
|
||||
# Recompute must be treated as non-first-forward in diagnostics.
|
||||
in_ckpt_first_forward = False
|
||||
# Keep KT autograd path whenever backward is needed. Disabling it in
|
||||
# checkpoint first-forward prevents KTMoEFunction.backward from running.
|
||||
use_autograd_path = save_for_backward
|
||||
save_for_backward_submit = use_autograd_path
|
||||
# Only suppress cache when we have high-confidence first_forward detection
|
||||
# via the saved_tensors_hooks stack. The stack-walk fallback is too fragile
|
||||
# for a correctness-critical decision — it only logs.
|
||||
if ckpt_hook_mode == "first_forward":
|
||||
if _checkpoint_hook_mode() == "first_forward":
|
||||
save_for_backward_submit = False
|
||||
|
||||
if train_lora and self._lora_pointers_dirty:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue