refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use the kt_ prefix, matching the dict keys; this eliminates
  the _OLD_TO_NEW mapping and prefix-stripping in wrapper.py (see the example below)
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py); the flag flows through to the C++ cache allocation
  (sketched below)
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check
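
As an illustration of the first and fourth bullets, a minimal sketch of the
now-direct construction path (the import path and the field values are made up;
the field and env var names come from kt_config.py):

    import os
    from kt_config import KTConfig  # illustrative import path

    # YAML / HfTrainerKTConfig dict keys now match the dataclass field names,
    # so the dict can be splatted straight into the constructor.
    cfg = KTConfig(**{
        "kt_backend": "AMXBF16",
        "kt_weight_path": "/path/to/weights",   # placeholder value
        "kt_num_gpu_experts": 8,                # placeholder value
    })

    # Fields left unset still fall back to ACCELERATE_KT_* env vars inside
    # __post_init__, which is why wrapper.py's own env fallbacks were redundant.
    os.environ["ACCELERATE_KT_SHARE_CACHE_POOL"] = "1"
    print(KTConfig().kt_share_cache_pool)       # picked up from the env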

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline
(1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)
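
A rough sketch of the kt_share_cache_pool auto-enable wiring mentioned above,
assuming a small hook in training_args.py that sees both the HF training
arguments and the KTConfig; the helper name and the respect-explicit-user-setting
check are illustrative, the real logic may differ:

    def _maybe_share_cache_pool(training_args, kt_config: KTConfig) -> None:
        # Auto-enable cache-pool sharing when gradient checkpointing is on,
        # but leave an explicitly configured value alone (unset fields are None).
        if getattr(training_args, "gradient_checkpointing", False) \
                and kt_config.kt_share_cache_pool is None:
            kt_config.kt_share_cache_pool = True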

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mrhaoxx 2026-04-09 14:17:50 +08:00
parent a98d544833
commit 020eb929f7
5 changed files with 127 additions and 172 deletions

@@ -11,6 +11,7 @@ KTransformersPlugin.kt_config (similar to DeepSpeedPlugin.hf_ds_config).
from __future__ import annotations
import dataclasses
import os
from dataclasses import dataclass, field
from typing import Any, Callable
@@ -42,83 +43,97 @@ class KTConfig:
"""
KT-Kernel configuration for SFT training.
All kt-kernel-specific settings live here. Accelerate's KTransformersPlugin
holds a reference to this via its `kt_config` field (similar to
DeepSpeedPlugin.hf_ds_config).
All field names use the ``kt_`` prefix so they match the dict keys used in
HfTrainerKTConfig / YAML configs. This means ``KTConfig(**dict)`` works
directly no name-mapping or prefix-stripping needed.
Can be created from:
- Direct construction: KTConfig(backend="AMXBF16", weight_path="/path/...")
- Direct construction: KTConfig(kt_backend="AMXBF16", kt_weight_path="/path/...")
- Dict: KTConfig(**config_dict)
- Environment variables: KTConfig() reads ACCELERATE_KT_* env vars as defaults
"""
    # Backend selection
    backend: str | None = None
    num_threads: int | None = None
    tp_enabled: bool | None = None
    threadpool_count: int | None = None
    kt_backend: str | None = None
    kt_num_threads: int | None = None
    kt_tp_enabled: bool | None = None
    kt_threadpool_count: int | None = None
    # Weight loading
    weight_path: str | None = None
    expert_checkpoint_path: str | None = None
    num_gpu_experts: int | None = None
    skip_expert_loading: bool | None = None
    share_backward_bb: bool | None = None
    kt_weight_path: str | None = None
    kt_expert_checkpoint_path: str | None = None
    kt_num_gpu_experts: int | None = None
    kt_skip_expert_loading: bool | None = None
    kt_share_backward_bb: bool | None = None
    kt_share_cache_pool: bool | None = None
    # Cache
    max_cache_depth: int | None = None
    model_max_length: int | None = None
    kt_max_cache_depth: int | None = None
    kt_model_max_length: int | None = None
    # LoRA
    lora_rank: int | None = None
    lora_alpha: float | None = None
    kt_lora_rank: int | None = None
    kt_lora_alpha: float | None = None
    # LoRA Experts (GPU-side extra experts)
    use_lora_experts: bool | None = None
    lora_expert_num: int | None = None
    lora_expert_intermediate_size: int | None = None
    kt_use_lora_experts: bool | None = None
    kt_lora_expert_num: int | None = None
    kt_lora_expert_intermediate_size: int | None = None
    # Runtime state (set during wrapping, not by user)
    checkpoint_files: list[str] | None = None
    sharded_metadata: dict | None = None
    kt_checkpoint_files: list[str] | None = None
    kt_sharded_metadata: dict | None = None
    # Custom wrapping
    wrap_fn: Callable[..., Any] | None = None
    wrap_kwargs: dict[str, Any] | None = None
    kt_wrap_fn: Callable[..., Any] | None = None
    kt_wrap_kwargs: dict[str, Any] | None = None

    @classmethod
    def from_object(cls, obj: Any) -> "KTConfig":
        """Create KTConfig from an attribute-based object (HfTrainerKTConfig, etc.)."""
        _field_names = {f.name for f in dataclasses.fields(cls)}
        kwargs: dict[str, Any] = {}
        for name in _field_names:
            val = getattr(obj, name, None)
            if val is not None:
                kwargs[name] = val
        return cls(**kwargs)

    def __post_init__(self):
        if self.backend is None:
            self.backend = os.environ.get("ACCELERATE_KT_BACKEND", "AMXBF16")
        if self.num_threads is None:
            self.num_threads = _env_int("ACCELERATE_KT_NUM_THREADS", 1)
        if self.tp_enabled is None:
            self.tp_enabled = _env_bool("ACCELERATE_KT_TP_ENABLED", False)
        if self.threadpool_count is None:
            self.threadpool_count = _env_int("ACCELERATE_KT_THREADPOOL_COUNT", 1)
        if self.weight_path is None:
            self.weight_path = os.environ.get("ACCELERATE_KT_WEIGHT_PATH", None)
        if self.expert_checkpoint_path is None:
            self.expert_checkpoint_path = os.environ.get("ACCELERATE_KT_EXPERT_CHECKPOINT_PATH", None)
        if self.num_gpu_experts is None:
            self.num_gpu_experts = _env_int("ACCELERATE_KT_NUM_GPU_EXPERTS", 0)
        if self.max_cache_depth is None:
            self.max_cache_depth = _env_int("ACCELERATE_KT_MAX_CACHE_DEPTH", 2)
        if self.share_backward_bb is None:
            self.share_backward_bb = _env_bool("ACCELERATE_KT_SHARE_BACKWARD_BB", False)
        if self.use_lora_experts is None:
            self.use_lora_experts = _env_bool("ACCELERATE_KT_USE_LORA_EXPERTS", False)
        if self.lora_expert_num is None:
            self.lora_expert_num = _env_int("ACCELERATE_KT_LORA_EXPERT_NUM", None)
        if self.lora_expert_intermediate_size is None:
            self.lora_expert_intermediate_size = _env_int("ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE", None)
        if self.lora_rank is None:
            self.lora_rank = _env_int("ACCELERATE_KT_LORA_RANK", None)
        if self.lora_alpha is None:
            self.lora_alpha = _env_float("ACCELERATE_KT_LORA_ALPHA", None)
        if self.lora_alpha is None and self.lora_rank is not None:
            self.lora_alpha = float(self.lora_rank * 2)
        if self.model_max_length is None:
            self.model_max_length = _env_int("ACCELERATE_KT_MODEL_MAX_LENGTH", None)
        if self.skip_expert_loading is None:
        if self.kt_backend is None:
            self.kt_backend = os.environ.get("ACCELERATE_KT_BACKEND", "AMXBF16")
        if self.kt_num_threads is None:
            self.kt_num_threads = _env_int("ACCELERATE_KT_NUM_THREADS", 1)
        if self.kt_tp_enabled is None:
            self.kt_tp_enabled = _env_bool("ACCELERATE_KT_TP_ENABLED", False)
        if self.kt_threadpool_count is None:
            self.kt_threadpool_count = _env_int("ACCELERATE_KT_THREADPOOL_COUNT", 1)
        if self.kt_weight_path is None:
            self.kt_weight_path = os.environ.get("ACCELERATE_KT_WEIGHT_PATH", None)
        if self.kt_expert_checkpoint_path is None:
            self.kt_expert_checkpoint_path = os.environ.get("ACCELERATE_KT_EXPERT_CHECKPOINT_PATH", None)
        if self.kt_num_gpu_experts is None:
            self.kt_num_gpu_experts = _env_int("ACCELERATE_KT_NUM_GPU_EXPERTS", 0)
        if self.kt_max_cache_depth is None:
            self.kt_max_cache_depth = _env_int("ACCELERATE_KT_MAX_CACHE_DEPTH", 2)
        if self.kt_share_backward_bb is None:
            self.kt_share_backward_bb = _env_bool("ACCELERATE_KT_SHARE_BACKWARD_BB", False)
        if self.kt_share_cache_pool is None:
            self.kt_share_cache_pool = _env_bool("ACCELERATE_KT_SHARE_CACHE_POOL", False)
        if self.kt_use_lora_experts is None:
            self.kt_use_lora_experts = _env_bool("ACCELERATE_KT_USE_LORA_EXPERTS", False)
        if self.kt_lora_expert_num is None:
            self.kt_lora_expert_num = _env_int("ACCELERATE_KT_LORA_EXPERT_NUM", None)
        if self.kt_lora_expert_intermediate_size is None:
            self.kt_lora_expert_intermediate_size = _env_int("ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE", None)
        if self.kt_lora_rank is None:
            self.kt_lora_rank = _env_int("ACCELERATE_KT_LORA_RANK", None)
        if self.kt_lora_alpha is None:
            self.kt_lora_alpha = _env_float("ACCELERATE_KT_LORA_ALPHA", None)
        if self.kt_lora_alpha is None and self.kt_lora_rank is not None:
            self.kt_lora_alpha = float(self.kt_lora_rank * 2)
        if self.kt_model_max_length is None:
            self.kt_model_max_length = _env_int("ACCELERATE_KT_MODEL_MAX_LENGTH", None)
        if self.kt_skip_expert_loading is None:
            if "ACCELERATE_KT_SKIP_EXPERT_LOADING" in os.environ:
                self.skip_expert_loading = _env_bool("ACCELERATE_KT_SKIP_EXPERT_LOADING", True)
                self.kt_skip_expert_loading = _env_bool("ACCELERATE_KT_SKIP_EXPERT_LOADING", True)