mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-29 04:09:52 +00:00
refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code
- KTConfig fields all use kt_ prefix matching dict keys — eliminates _OLD_TO_NEW mapping and prefix-stripping in wrapper.py - Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing is on (via training_args.py), flows through to C++ cache allocation - Remove dead checkpoint detection code: in_ckpt_recompute, in_ckpt_first_forward vars (assigned but never read), fallback _is_in_checkpoint_first_forward() function, unused inspect import - Remove redundant env var fallbacks in wrapper.py for share_backward_bb and share_cache_pool (KTConfig.__post_init__ already handles env vars) - Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check Verified: Qwen3-235B 3-step training on sap4, loss matches baseline (1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a98d544833
commit
020eb929f7
5 changed files with 127 additions and 172 deletions
|
|
@ -11,6 +11,7 @@ KTransformersPlugin.kt_config (similar to DeepSpeedPlugin.hf_ds_config).
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
|
@ -42,83 +43,97 @@ class KTConfig:
|
|||
"""
|
||||
KT-Kernel configuration for SFT training.
|
||||
|
||||
All kt-kernel-specific settings live here. Accelerate's KTransformersPlugin
|
||||
holds a reference to this via its `kt_config` field (similar to
|
||||
DeepSpeedPlugin.hf_ds_config).
|
||||
All field names use the ``kt_`` prefix so they match the dict keys used in
|
||||
HfTrainerKTConfig / YAML configs. This means ``KTConfig(**dict)`` works
|
||||
directly — no name-mapping or prefix-stripping needed.
|
||||
|
||||
Can be created from:
|
||||
- Direct construction: KTConfig(backend="AMXBF16", weight_path="/path/...")
|
||||
- Direct construction: KTConfig(kt_backend="AMXBF16", kt_weight_path="/path/...")
|
||||
- Dict: KTConfig(**config_dict)
|
||||
- Environment variables: KTConfig() reads ACCELERATE_KT_* env vars as defaults
|
||||
"""
|
||||
|
||||
# Backend selection
|
||||
backend: str | None = None
|
||||
num_threads: int | None = None
|
||||
tp_enabled: bool | None = None
|
||||
threadpool_count: int | None = None
|
||||
kt_backend: str | None = None
|
||||
kt_num_threads: int | None = None
|
||||
kt_tp_enabled: bool | None = None
|
||||
kt_threadpool_count: int | None = None
|
||||
|
||||
# Weight loading
|
||||
weight_path: str | None = None
|
||||
expert_checkpoint_path: str | None = None
|
||||
num_gpu_experts: int | None = None
|
||||
skip_expert_loading: bool | None = None
|
||||
share_backward_bb: bool | None = None
|
||||
kt_weight_path: str | None = None
|
||||
kt_expert_checkpoint_path: str | None = None
|
||||
kt_num_gpu_experts: int | None = None
|
||||
kt_skip_expert_loading: bool | None = None
|
||||
kt_share_backward_bb: bool | None = None
|
||||
kt_share_cache_pool: bool | None = None
|
||||
|
||||
# Cache
|
||||
max_cache_depth: int | None = None
|
||||
model_max_length: int | None = None
|
||||
kt_max_cache_depth: int | None = None
|
||||
kt_model_max_length: int | None = None
|
||||
|
||||
# LoRA
|
||||
lora_rank: int | None = None
|
||||
lora_alpha: float | None = None
|
||||
kt_lora_rank: int | None = None
|
||||
kt_lora_alpha: float | None = None
|
||||
|
||||
# LoRA Experts (GPU-side extra experts)
|
||||
use_lora_experts: bool | None = None
|
||||
lora_expert_num: int | None = None
|
||||
lora_expert_intermediate_size: int | None = None
|
||||
kt_use_lora_experts: bool | None = None
|
||||
kt_lora_expert_num: int | None = None
|
||||
kt_lora_expert_intermediate_size: int | None = None
|
||||
|
||||
# Runtime state (set during wrapping, not by user)
|
||||
checkpoint_files: list[str] | None = None
|
||||
sharded_metadata: dict | None = None
|
||||
kt_checkpoint_files: list[str] | None = None
|
||||
kt_sharded_metadata: dict | None = None
|
||||
|
||||
# Custom wrapping
|
||||
wrap_fn: Callable[..., Any] | None = None
|
||||
wrap_kwargs: dict[str, Any] | None = None
|
||||
kt_wrap_fn: Callable[..., Any] | None = None
|
||||
kt_wrap_kwargs: dict[str, Any] | None = None
|
||||
|
||||
@classmethod
def from_object(cls, obj: Any) -> "KTConfig":
    """Create KTConfig from an attribute-based object (HfTrainerKTConfig, etc.)."""
    # Pull each dataclass field straight off *obj*; attributes that are
    # missing or None are skipped so __post_init__ defaults still apply.
    init_kwargs: dict[str, Any] = {
        f.name: value
        for f in dataclasses.fields(cls)
        if (value := getattr(obj, f.name, None)) is not None
    }
    return cls(**init_kwargs)
|
||||
|
||||
def __post_init__(self):
    """Fill every unset (None) field from its ACCELERATE_KT_* environment
    variable, falling back to a hard default where one exists.

    Explicitly passed (non-None) values always take precedence over the
    environment. Fields whose env default is None simply stay None.
    """
    # Backend selection
    if self.kt_backend is None:
        self.kt_backend = os.environ.get("ACCELERATE_KT_BACKEND", "AMXBF16")
    if self.kt_num_threads is None:
        self.kt_num_threads = _env_int("ACCELERATE_KT_NUM_THREADS", 1)
    if self.kt_tp_enabled is None:
        self.kt_tp_enabled = _env_bool("ACCELERATE_KT_TP_ENABLED", False)
    if self.kt_threadpool_count is None:
        self.kt_threadpool_count = _env_int("ACCELERATE_KT_THREADPOOL_COUNT", 1)
    # Weight loading
    if self.kt_weight_path is None:
        self.kt_weight_path = os.environ.get("ACCELERATE_KT_WEIGHT_PATH", None)
    if self.kt_expert_checkpoint_path is None:
        self.kt_expert_checkpoint_path = os.environ.get("ACCELERATE_KT_EXPERT_CHECKPOINT_PATH", None)
    if self.kt_num_gpu_experts is None:
        self.kt_num_gpu_experts = _env_int("ACCELERATE_KT_NUM_GPU_EXPERTS", 0)
    # Cache
    if self.kt_max_cache_depth is None:
        self.kt_max_cache_depth = _env_int("ACCELERATE_KT_MAX_CACHE_DEPTH", 2)
    if self.kt_share_backward_bb is None:
        self.kt_share_backward_bb = _env_bool("ACCELERATE_KT_SHARE_BACKWARD_BB", False)
    if self.kt_share_cache_pool is None:
        self.kt_share_cache_pool = _env_bool("ACCELERATE_KT_SHARE_CACHE_POOL", False)
    # LoRA Experts
    if self.kt_use_lora_experts is None:
        self.kt_use_lora_experts = _env_bool("ACCELERATE_KT_USE_LORA_EXPERTS", False)
    if self.kt_lora_expert_num is None:
        self.kt_lora_expert_num = _env_int("ACCELERATE_KT_LORA_EXPERT_NUM", None)
    if self.kt_lora_expert_intermediate_size is None:
        self.kt_lora_expert_intermediate_size = _env_int("ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE", None)
    # LoRA
    if self.kt_lora_rank is None:
        self.kt_lora_rank = _env_int("ACCELERATE_KT_LORA_RANK", None)
    if self.kt_lora_alpha is None:
        self.kt_lora_alpha = _env_float("ACCELERATE_KT_LORA_ALPHA", None)
    # Convention: alpha defaults to 2 * rank when only the rank was given.
    if self.kt_lora_alpha is None and self.kt_lora_rank is not None:
        self.kt_lora_alpha = float(self.kt_lora_rank * 2)
    if self.kt_model_max_length is None:
        self.kt_model_max_length = _env_int("ACCELERATE_KT_MODEL_MAX_LENGTH", None)
    # Only consult the env var when it is actually set, so the field stays
    # None (meaning "decide later") when neither user nor env specified it.
    if self.kt_skip_expert_loading is None:
        if "ACCELERATE_KT_SKIP_EXPERT_LOADING" in os.environ:
            self.kt_skip_expert_loading = _env_bool("ACCELERATE_KT_SKIP_EXPERT_LOADING", True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue