refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code
- KTConfig fields all use kt_ prefix matching dict keys — eliminates _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute, in_ckpt_first_forward vars (assigned but never read), fallback _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline (1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Commit 020eb929f7 (parent a98d544833)
5 changed files with 127 additions and 172 deletions
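Note on the kt_share_cache_pool auto-enable: the training_args.py hunk is not shown in this view. As a rough sketch only (kt_share_cache_pool is the real field name from the commit message and HfTrainerKTConfig is referenced elsewhere in this diff, but the helper below and its exact placement are assumptions), the wiring could look like:

from dataclasses import dataclass, field


@dataclass
class HfTrainerKTConfig:
    # Sketch only: whether the shared C++ backward cache pool is enabled.
    kt_share_cache_pool: bool | None = field(
        default=None,
        metadata={"help": "Share the backward cache pool across MoE layers."},
    )


def apply_gradient_checkpointing_defaults(kt_args: HfTrainerKTConfig, training_args) -> None:
    # Only fill in the value when the user left it unset; an explicit
    # True/False from the command line wins.
    if kt_args.kt_share_cache_pool is None:
        kt_args.kt_share_cache_pool = bool(getattr(training_args, "gradient_checkpointing", False))

Tying the default to gradient_checkpointing presumably works because recomputation touches one layer at a time, so a single shared cache pool can serve all wrapped layers.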
@@ -51,47 +51,19 @@ else:
 def _get_kt_config(kt_plugin: Any):
     """Extract KTConfig from a KTransformersPlugin or compatible object.
 
-    Handles three cases:
-    1. KTransformersPlugin with .kt_config (new style) → return kt_config
-    2. Object with old field names (kt_num_threads etc.) → convert to KTConfig
-    3. KTConfig directly → return as-is
+    KTConfig field names use kt_ prefix, matching the dict keys in
+    HfTrainerKTConfig exactly — no name-mapping needed.
     """
     from .config import KTConfig
 
-    # New-style KTransformersPlugin
+    if isinstance(kt_plugin, KTConfig):
+        return kt_plugin
+
     kt_config = getattr(kt_plugin, "kt_config", None)
     if kt_config is not None and isinstance(kt_config, KTConfig):
         return kt_config
 
-    # Already a KTConfig
-    if isinstance(kt_plugin, KTConfig):
-        return kt_plugin
-
-    # Old-style object (HfTrainerKTConfig, old KTransformersPlugin, dict-like) — convert
-    # Map old field names (kt_xxx) to new field names (xxx)
-    _OLD_TO_NEW = {
-        "kt_backend": "backend", "kt_num_threads": "num_threads",
-        "kt_tp_enabled": "tp_enabled", "kt_threadpool_count": "threadpool_count",
-        "kt_weight_path": "weight_path", "kt_expert_checkpoint_path": "expert_checkpoint_path",
-        "kt_num_gpu_experts": "num_gpu_experts", "kt_max_cache_depth": "max_cache_depth",
-        "kt_use_lora_experts": "use_lora_experts", "kt_lora_expert_num": "lora_expert_num",
-        "kt_lora_expert_intermediate_size": "lora_expert_intermediate_size",
-        "kt_skip_expert_loading": "skip_expert_loading",
-        "kt_share_backward_bb": "share_backward_bb",
-        "kt_checkpoint_files": "checkpoint_files",
-        "kt_sharded_metadata": "sharded_metadata",
-    }
-    kwargs = {}
-    for old_name, new_name in _OLD_TO_NEW.items():
-        val = getattr(kt_plugin, old_name, None)
-        if val is not None:
-            kwargs[new_name] = val
-    # Fields that don't have kt_ prefix
-    for name in ("lora_rank", "lora_alpha", "model_max_length", "wrap_fn", "wrap_kwargs"):
-        val = getattr(kt_plugin, name, None)
-        if val is not None:
-            kwargs[name] = val
-    return KTConfig(**kwargs)
+    return KTConfig.from_object(kt_plugin)
 
 
 def build_kt_device_map(config, kt_plugin, device: str = "cuda:0") -> dict[str, str | int]:
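The hand-written _OLD_TO_NEW mapping above is replaced by a single KTConfig.from_object(...) call. That classmethod's body is not part of this hunk; the following is only a sketch of what a generic converter can look like once every field carries the kt_ prefix (the field list is abbreviated and the implementation is an assumption):

from dataclasses import dataclass, fields


@dataclass
class KTConfig:
    kt_backend: str | None = None
    kt_num_threads: int | None = None
    kt_num_gpu_experts: int | None = None
    # ... remaining kt_* fields elided in this sketch ...

    @classmethod
    def from_object(cls, obj) -> "KTConfig":
        # Hypothetical sketch: with matching names, conversion is a straight copy.
        kwargs = {}
        for f in fields(cls):
            # Works for attribute-style objects and for dicts keyed by kt_* names.
            val = obj.get(f.name) if isinstance(obj, dict) else getattr(obj, f.name, None)
            if val is not None:
                kwargs[f.name] = val
        return cls(**kwargs)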
@@ -103,7 +75,7 @@ def build_kt_device_map(config, kt_plugin, device: str = "cuda:0") -> dict[str,
     num_layers = config.num_hidden_layers
     num_experts = moe_config.expert_num
     cfg = _get_kt_config(kt_plugin)
-    num_gpu_experts = getattr(cfg, "num_gpu_experts", 0) or 0
+    num_gpu_experts = getattr(cfg, "kt_num_gpu_experts", 0) or 0
 
     device_map: dict[str, str | int] = {}
 
@@ -138,7 +110,7 @@ def build_kt_device_map_simplified(config, kt_plugin, device: str = "cuda:0") ->
     layers_prefix = _get_layers_prefix(config)
     num_layers = config.num_hidden_layers
     cfg = _get_kt_config(kt_plugin)
-    num_gpu_experts = getattr(cfg, "num_gpu_experts", 0) or 0
+    num_gpu_experts = getattr(cfg, "kt_num_gpu_experts", 0) or 0
 
     device_map: dict[str, str | int] = {}
 
@@ -190,14 +162,14 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
     cfg = _get_kt_config(kt_plugin)
 
     # Read lora_rank/lora_alpha for C++ wrapper initialization (buffer allocation only)
-    lora_rank = getattr(cfg, "lora_rank", 1) or 1
-    lora_alpha = getattr(cfg, "lora_alpha", 1.0) or 1.0
+    lora_rank = getattr(cfg, "kt_lora_rank", 1) or 1
+    lora_alpha = getattr(cfg, "kt_lora_alpha", 1.0) or 1.0
 
     # Read LoRA Experts configuration
-    _raw_le = getattr(cfg, "use_lora_experts", None)
+    _raw_le = getattr(cfg, "kt_use_lora_experts", None)
     use_lora_experts = bool(_raw_le) if _raw_le is not None else False
-    lora_expert_num = getattr(cfg, "lora_expert_num", 2) or 2
-    lora_expert_intermediate_size = getattr(cfg, "lora_expert_intermediate_size", 1024) or 1024
+    lora_expert_num = getattr(cfg, "kt_lora_expert_num", 2) or 2
+    lora_expert_intermediate_size = getattr(cfg, "kt_lora_expert_intermediate_size", 1024) or 1024
 
     if is_rank_0:
         logger.info(
@@ -218,7 +190,7 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
     }
     # Build case-insensitive lookup to handle common typos like "SkipLora" vs "SkipLoRA"
    _kt_backend_map_lower = {k.lower(): v for k, v in kt_backend_map.items()}
-    kt_backend = getattr(cfg, "backend", "AMXBF16")
+    kt_backend = getattr(cfg, "kt_backend", "AMXBF16")
     kt_method = kt_backend_map.get(kt_backend) or _kt_backend_map_lower.get(kt_backend.lower(), "AMXBF16_SFT")
     if kt_method != kt_backend_map.get(kt_backend):
         logger.warning(
@@ -229,27 +201,27 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
     if "SkipLoRA" in kt_method:
         logger.info(f"Using SkipLoRA backend: {kt_method} (MoE LoRA gradients will be skipped)")
 
-    threadpool_count = getattr(cfg, "threadpool_count", 1) if getattr(cfg, "tp_enabled", False) else 1
+    threadpool_count = getattr(cfg, "kt_threadpool_count", 1) if getattr(cfg, "kt_tp_enabled", False) else 1
 
-    kt_weight_path = getattr(cfg, "weight_path", None)
+    kt_weight_path = getattr(cfg, "kt_weight_path", None)
     use_kt_weight_path = kt_weight_path is not None
     if use_kt_weight_path:
         logger.info(f"Loading INT8 weights from kt_weight_path: {kt_weight_path}")
 
-    checkpoint_files = getattr(cfg, "checkpoint_files", None)
-    sharded_metadata = getattr(cfg, "sharded_metadata", None)
+    checkpoint_files = getattr(cfg, "kt_checkpoint_files", None)
+    sharded_metadata = getattr(cfg, "kt_sharded_metadata", None)
 
     # When kt_expert_checkpoint_path is set, always resolve from it (overrides any existing
     # checkpoint_files which may come from AttnOnlyBf16 and lack expert weights).
-    kt_expert_checkpoint_path = getattr(cfg, "expert_checkpoint_path", None)
+    kt_expert_checkpoint_path = getattr(cfg, "kt_expert_checkpoint_path", None)
     if kt_expert_checkpoint_path:
         logger.info(f"Resolving expert checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}")
         resolved_files, resolved_meta = _resolve_checkpoint_files(model_name_or_path=kt_expert_checkpoint_path)
         if resolved_files and all(f.endswith(".safetensors") for f in resolved_files):
             checkpoint_files = resolved_files
             sharded_metadata = resolved_meta
-            cfg.checkpoint_files = checkpoint_files
-            cfg.sharded_metadata = sharded_metadata
+            cfg.kt_checkpoint_files = checkpoint_files
+            cfg.kt_sharded_metadata = sharded_metadata
             logger.info(f"Resolved {len(checkpoint_files)} checkpoint files from kt_expert_checkpoint_path")
         else:
             logger.warning(f"Failed to resolve checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}")
@@ -267,7 +239,7 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
         logger.info("Loading expert weights from checkpoint files (online conversion).")
     elif use_kt_weight_path and bool(checkpoint_files):
         logger.info("BF16 checkpoint files available for backward gradient computation.")
-    elif (not use_kt_weight_path) and bool(getattr(cfg, "skip_expert_loading", False)):
+    elif (not use_kt_weight_path) and bool(getattr(cfg, "kt_skip_expert_loading", False)):
         # If HF expert weights were skipped during `from_pretrained`, we must source expert weights externally.
         model_name_or_path = getattr(getattr(model, "config", None), "name_or_path", None)
         if model_name_or_path:
@@ -275,8 +247,8 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
             if resolved_files and all(f.endswith(".safetensors") for f in resolved_files):
                 checkpoint_files = resolved_files
                 sharded_metadata = resolved_meta
-                cfg.checkpoint_files = checkpoint_files
-                cfg.sharded_metadata = sharded_metadata
+                cfg.kt_checkpoint_files = checkpoint_files
+                cfg.kt_sharded_metadata = sharded_metadata
                 use_checkpoint_files = True
                 logger.info("KT skip_expert_loading enabled; using checkpoint files for online expert loading.")
 
@@ -328,7 +300,7 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
             up_proj = up_proj.cpu().to(torch.bfloat16).contiguous()
             down_proj = down_proj.cpu().to(torch.bfloat16).contiguous()
 
-        chunked_prefill_size = getattr(cfg, "model_max_length", None)
+        chunked_prefill_size = getattr(cfg, "kt_model_max_length", None)
         if chunked_prefill_size is None:
             chunked_prefill_size = getattr(model.config, "max_position_embeddings", 4096)
 
@@ -341,7 +313,7 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
             hidden_size=hidden_size,
             moe_intermediate_size=moe_config.intermediate_size,
             num_gpu_experts=0,
-            cpuinfer_threads=getattr(cfg, "num_threads", 1),
+            cpuinfer_threads=getattr(cfg, "kt_num_threads", 1),
             threadpool_count=threadpool_count,
             weight_path=kt_weight_path or "",
             chunked_prefill_size=chunked_prefill_size,
@@ -349,14 +321,12 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
             mode="sft",
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
-            max_cache_depth=getattr(cfg, "max_cache_depth", 2),
+            max_cache_depth=getattr(cfg, "kt_max_cache_depth", 2),
         )
 
-        # Set share_backward_bb BEFORE load_weights (config is built during load)
-        share_backward_bb = getattr(cfg, "share_backward_bb", None)
-        if share_backward_bb is None:
-            share_backward_bb = os.environ.get("ACCELERATE_KT_SHARE_BACKWARD_BB", "").lower() in ("true", "1", "yes")
-        wrapper.share_backward_bb = share_backward_bb
+        # Set share_backward_bb and share_cache_pool BEFORE load_weights (config is built during load)
+        wrapper.share_backward_bb = cfg.kt_share_backward_bb
+        wrapper.share_cache_pool = cfg.kt_share_cache_pool
 
         physical_to_logical_map = torch.arange(moe_config.expert_num, dtype=torch.int64, device="cpu")
 
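The env-var fallback deleted above is redundant because, per the commit message, KTConfig.__post_init__ already resolves these flags. The config.py side is not shown in this view; the following is only a sketch of that kind of __post_init__ handling (the ACCELERATE_KT_SHARE_CACHE_POOL name is an assumption, mirrored from the ACCELERATE_KT_SHARE_BACKWARD_BB variable in the removed code, and the class is trimmed to the two relevant fields):

import os
from dataclasses import dataclass


def _env_flag(name: str) -> bool:
    """Interpret an environment variable as a boolean flag."""
    return os.environ.get(name, "").lower() in ("true", "1", "yes")


@dataclass
class KTConfig:
    # Sketch only: trimmed to the two fields touched by this hunk.
    kt_share_backward_bb: bool | None = None
    kt_share_cache_pool: bool | None = None

    def __post_init__(self):
        # Env vars only fill in values the caller left unset.
        if self.kt_share_backward_bb is None:
            self.kt_share_backward_bb = _env_flag("ACCELERATE_KT_SHARE_BACKWARD_BB")
        if self.kt_share_cache_pool is None:
            self.kt_share_cache_pool = _env_flag("ACCELERATE_KT_SHARE_CACHE_POOL")

With the flags resolved once in the config, the wrapper can simply read cfg.kt_share_backward_bb and cfg.kt_share_cache_pool, as the new lines above do.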
@@ -443,20 +413,20 @@ def _build_kt_plugin_from_args(model_args: Any, finetuning_args: Any | None = No
     from accelerate.utils.dataclasses import KTransformersPlugin
 
     kt_config = KTConfig(
-        backend=getattr(model_args, "kt_backend", None),
-        num_threads=getattr(model_args, "kt_num_threads", None),
-        tp_enabled=getattr(model_args, "kt_tp_enabled", None),
-        threadpool_count=getattr(model_args, "kt_threadpool_count", None),
-        max_cache_depth=getattr(model_args, "kt_max_cache_depth", None),
-        num_gpu_experts=getattr(model_args, "kt_num_gpu_experts", None),
-        weight_path=getattr(model_args, "kt_weight_path", None),
-        expert_checkpoint_path=getattr(model_args, "kt_expert_checkpoint_path", None),
-        use_lora_experts=getattr(model_args, "kt_use_lora_experts", None),
-        lora_expert_num=getattr(model_args, "kt_lora_expert_num", None),
-        lora_expert_intermediate_size=getattr(model_args, "kt_lora_expert_intermediate_size", None),
-        lora_rank=getattr(finetuning_args, "lora_rank", None) if finetuning_args else None,
-        lora_alpha=getattr(finetuning_args, "lora_alpha", None) if finetuning_args else None,
-        model_max_length=getattr(model_args, "model_max_length", None),
+        kt_backend=getattr(model_args, "kt_backend", None),
+        kt_num_threads=getattr(model_args, "kt_num_threads", None),
+        kt_tp_enabled=getattr(model_args, "kt_tp_enabled", None),
+        kt_threadpool_count=getattr(model_args, "kt_threadpool_count", None),
+        kt_max_cache_depth=getattr(model_args, "kt_max_cache_depth", None),
+        kt_num_gpu_experts=getattr(model_args, "kt_num_gpu_experts", None),
+        kt_weight_path=getattr(model_args, "kt_weight_path", None),
+        kt_expert_checkpoint_path=getattr(model_args, "kt_expert_checkpoint_path", None),
+        kt_use_lora_experts=getattr(model_args, "kt_use_lora_experts", None),
+        kt_lora_expert_num=getattr(model_args, "kt_lora_expert_num", None),
+        kt_lora_expert_intermediate_size=getattr(model_args, "kt_lora_expert_intermediate_size", None),
+        kt_lora_rank=getattr(finetuning_args, "lora_rank", None) if finetuning_args else None,
+        kt_lora_alpha=getattr(finetuning_args, "lora_alpha", None) if finetuning_args else None,
+        kt_model_max_length=getattr(model_args, "model_max_length", None),
     )
     return KTransformersPlugin(enabled=True, kt_config=kt_config)
 
@@ -570,21 +540,21 @@ def load_kt_model(
 
     cfg = _get_kt_config(kt_plugin)
 
-    if getattr(cfg, "skip_expert_loading", None) is None:
+    if getattr(cfg, "kt_skip_expert_loading", None) is None:
         checkpoint_files, sharded_metadata = _resolve_checkpoint_files(
             model_name_or_path=model_name_or_path,
             cache_dir=cache_dir, revision=revision,
             token=token, trust_remote_code=trust_remote_code,
         )
         if checkpoint_files and all(f.endswith(".safetensors") for f in checkpoint_files):
-            if getattr(cfg, "weight_path", None) is None:
-                cfg.skip_expert_loading = True
+            if getattr(cfg, "kt_weight_path", None) is None:
+                cfg.kt_skip_expert_loading = True
             else:
-                cfg.skip_expert_loading = False
-                cfg.checkpoint_files = checkpoint_files
-                cfg.sharded_metadata = sharded_metadata
+                cfg.kt_skip_expert_loading = False
+                cfg.kt_checkpoint_files = checkpoint_files
+                cfg.kt_sharded_metadata = sharded_metadata
         else:
-            cfg.skip_expert_loading = False
+            cfg.kt_skip_expert_loading = False
 
     set_kt_config(kt_plugin)
     try:
@@ -603,8 +573,8 @@ def load_kt_model(
     wrappers = wrap_moe_layers_with_kt_wrapper(model, kt_plugin)
 
     model._kt_wrappers = wrappers
-    model._kt_tp_enabled = bool(getattr(cfg, "tp_enabled", False))
-    model._kt_use_lora_experts = bool(getattr(cfg, "use_lora_experts", False))
+    model._kt_tp_enabled = bool(getattr(cfg, "kt_tp_enabled", False))
+    model._kt_use_lora_experts = bool(getattr(cfg, "kt_use_lora_experts", False))
 
     logger.info("Model loaded with KTMoEWrapper backend successfully")
     return model