# Source: kvcache-ai-ktransformers/KT-SFT/ktransformers/sft/lora.py
# Snapshot: 2025-11-04 04:24:30 +00:00
# Viewer metadata: 497 lines, 20 KiB, Python, no EOL at end of file

# Standard library.
import functools
import inspect
import json
import os
from pathlib import Path
from typing import Union, Any, Dict, List

# Third-party.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
from transformers.trainer_utils import seed_worker
from transformers.utils import (
    is_datasets_available,
    is_sagemaker_mp_enabled,
    is_torch_xpu_available,
    is_torch_mlu_available,
    is_torch_musa_available,
    is_torch_npu_available,
    is_torch_mps_available,
    is_torch_hpu_available,
    is_accelerate_available,
    is_apex_available,
    logging,
)
from packaging import version
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data import Dataset as TorchDataset
from peft import LoraConfig, TaskType
from datasets import Dataset
from torchviz import make_dot
from tqdm import tqdm
from accelerate import Accelerator
from accelerate import __version__ as accelerate_version

# Optional imports: each name is only bound when the guarding check passes, so
# code that uses it must sit behind the same check.
if is_accelerate_available("0.28.0"):
    from accelerate.utils import DataLoaderConfiguration

if version.parse(accelerate_version) > version.parse("1.3.0"):
    from accelerate.utils import TorchTensorParallelPlugin

if is_sagemaker_mp_enabled():
    # `smp_forward_backward` is defined in `trainer_pt_utils`, not `trainer_utils`.
    from transformers.trainer_pt_utils import smp_forward_backward

# Project-local.
from ktransformers.sft.peft_utils.mapping import get_peft_model

logger = logging.get_logger(__name__)
class KAccelerator(Accelerator):
    """`Accelerator` subclass that never touches `nn.Module` objects.

    Construction defaults ``device_placement`` to ``False``; `prepare_model`
    returns the model it is given, and `prepare` routes every `nn.Module`
    through `prepare_model` while delegating all other objects to the parent
    class one at a time.
    """

    def __init__(self, *args, **kwargs):
        """Build the accelerator, defaulting ``device_placement`` to ``False``."""
        kwargs.setdefault("device_placement", False)
        super().__init__(*args, **kwargs)

    def prepare_model(self, model, *args, **kwargs):
        """Return ``model`` unchanged; extra arguments are accepted and ignored."""
        return model

    def prepare(self, *args, **kwargs):
        """Prepare each positional object individually.

        Args:
            *args: Objects to prepare. `nn.Module` instances are returned as-is
                (via `prepare_model`); everything else goes through
                ``Accelerator.prepare``.
            **kwargs: Forwarded to the per-object calls. ``device_placement``,
                when given, is a sequence with one entry per positional object.

        Returns:
            The single prepared object when one was passed, otherwise a tuple
            of prepared objects in the same order as ``args``.

        Raises:
            ValueError: If ``device_placement`` does not have one entry per
                positional object.
        """
        placements = kwargs.pop("device_placement", None)
        if placements is not None and len(placements) != len(args):
            raise ValueError(
                f"`device_placement` should be a list with {len(args)} elements (the number of objects passed)."
            )

        prepped = []
        for index, obj in enumerate(args):
            if isinstance(obj, nn.Module):
                prepped.append(self.prepare_model(obj, **kwargs))
                continue
            call_kwargs = dict(kwargs)
            if placements is not None:
                # The parent validates len(device_placement) against the number
                # of objects in *that* call, so hand it only this object's entry.
                call_kwargs["device_placement"] = [placements[index]]
            prepped.append(super().prepare(obj, **call_kwargs))
        return tuple(prepped) if len(prepped) > 1 else prepped[0]
class KTrainer(Trainer):
    """`Trainer` variant that leaves model placement and wrapping untouched.

    The overrides skip moving and wrapping the model, build a `KAccelerator`
    with ``device_placement=False``, size the training DataLoader by
    ``per_device_train_batch_size`` and save through ``model.save_pretrained``.
    """

    def save_model(self, output_dir=None, _internal_call=False):
        """Write the model with ``save_pretrained``.

        Args:
            output_dir: Target directory; falls back to ``self.args.output_dir``.
            _internal_call: Accepted for signature compatibility; not used here.
        """
        output_dir = output_dir or self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        # only save LoRA adapter, including adapter_config.json
        self.model.save_pretrained(output_dir)

    def _move_model_to_device(self, model, device):
        """Return ``model`` without moving it; only prints a notice."""
        print("[KTrainer] Due to the placement feature in KTransformers, skip moving model to", device)
        return model

    def _wrap_model(self, model, training=True, dataloader=None):
        """Record ``model`` as ``self.model_wrapped`` and return it unwrapped."""
        self.model_wrapped = model
        return model

    def create_accelerator_and_postprocess(self):
        """Create ``self.accelerator`` (a `KAccelerator`) and derive the related flags.

        Raises:
            ValueError: On conflicting gradient-accumulation settings, when tensor
                parallelism is requested without a new enough accelerate, or on the
                unsupported DeepSpeed/FSDP option combinations checked below.
            ImportError: If ``non_blocking`` is requested with accelerate < 0.30.0, or
                if ``DataLoaderConfiguration`` is unavailable.
        """
        # We explicitly don't rely on the `Accelerator` to do gradient accumulation
        grad_acc_kwargs = {}
        if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
            grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

        # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
        if "num_steps" in grad_acc_kwargs:
            if self.args.gradient_accumulation_steps > 1:
                # raise because we do not know which setting is intended.
                raise ValueError(
                    "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
                    "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
                )
            else:
                self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

        accelerator_config = self.args.accelerator_config.to_dict()

        if is_accelerate_available("0.28.0"):
            # Extract dataloader config params from accelerator config
            dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"]
            dataloader_config_dict = {
                param: accelerator_config.pop(param) for param in dataloader_params if param in accelerator_config
            }
            if DataLoaderConfiguration is None:
                raise ImportError("Your accelerate does not provide DataLoaderConfiguration but Trainer expects it.")
            dataloader_config = DataLoaderConfiguration(**dataloader_config_dict)
            if is_accelerate_available("1.1.0"):
                dataloader_config.data_seed = self.args.data_seed
        else:
            dataloader_config = None

        non_blocking = accelerator_config.pop("non_blocking", False)
        if not is_accelerate_available("0.30.0"):
            if non_blocking:
                raise ImportError(
                    "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature."
                )
        else:
            if non_blocking and not self.args.dataloader_pin_memory:
                logger.warning("`non_blocking` is enabled but `dataloader_pin_memory` is not. For best performance, enable both.")
            if dataloader_config is not None:
                dataloader_config.non_blocking = non_blocking

        # Gradient accumulation was resolved above; drop it before the dict is reused.
        accelerator_config.pop("gradient_accumulation_kwargs", None)

        args = {
            "deepspeed_plugin": self.args.deepspeed_plugin,
            "device_placement": False,
        }
        if is_accelerate_available("0.28.0"):
            args["dataloader_config"] = dataloader_config
        else:
            args.update(accelerator_config)

        if getattr(self.args, "tp_size", 1) > 1:
            self.is_tp_enabled = True
            if version.parse(accelerate_version) > version.parse("1.3.0") and TorchTensorParallelPlugin is not None:
                args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
            else:
                raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")

        self.accelerator = KAccelerator(**args)

        # Best-effort override of the accelerator state to a single process/GPU.
        # A failure is reported instead of being swallowed, but never aborts setup.
        try:
            self.accelerator.state.device_ids = [0]
            self.accelerator.state.num_processes = 1
            self.accelerator.state.num_gpus = 1
        except Exception as exc:
            logger.warning(f"[KTrainer] Could not override accelerator state to a single process: {exc!r}")

        # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
        self.gather_function = self.accelerator.gather_for_metrics
        if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
            self.gather_function = functools.partial(
                self.gather_function, use_gather_object=self.args.eval_use_gather_object
            )

        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None

        # post accelerator creation setup
        if self.is_fsdp_enabled:
            fsdp_plugin = self.accelerator.state.fsdp_plugin
            for param in ["limit_all_gathers", "activation_checkpointing"]:
                setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param)))
            if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
                raise ValueError(
                    "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
                    "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
                    "when using FSDP."
                )

        if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
            self.propagate_args_to_deepspeed()

        # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
        if (
            self.args.save_only_model
            and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
            and self.args.load_best_model_at_end
        ):
            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
            raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")

        # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3
        if (
            self.is_deepspeed_enabled
            and self.accelerator.state.deepspeed_plugin.zero_stage == 3
            and self.args.auto_find_batch_size
        ):
            raise ValueError(
                "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP"
            )

        if (
            self.args.save_only_model
            and self.is_fsdp_enabled
            and "SHARDED_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)
        ):
            raise ValueError("save_only_model option is not compatible with FSDP state dict type 'SHARDED_STATE_DICT'")

        # Turn off every batch splitting/dispatching option on the dataloader config.
        if dataloader_config is not None:
            dataloader_config.split_batches = False
            dataloader_config.dispatch_batches = False
            dataloader_config.even_batches = False

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training DataLoader with per_device_train_batch_size
        (no implicit multipliers by number of visible GPUs).

        Raises:
            ValueError: If ``self.train_dataset`` is ``None``.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator

        # `datasets.Dataset` objects get unused columns removed up front; anything
        # else (or any failure along the way) falls back to a column-filtering collator.
        if is_datasets_available():
            try:
                import datasets

                if isinstance(train_dataset, datasets.Dataset):
                    train_dataset = self._remove_unused_columns(train_dataset, description="training")
                else:
                    data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
            except Exception:
                data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        dataloader_params = {
            "batch_size": self.args.per_device_train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(train_dataset, IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker

        # `prefetch_factor` is only passed when worker processes are in use.
        if self.args.dataloader_num_workers > 0 and self.args.dataloader_prefetch_factor is not None:
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        dl = DataLoader(train_dataset, **dataloader_params)

        # Retry without `device_placement` if the accelerator rejects the keyword.
        try:
            prepared = self.accelerator.prepare(dl, device_placement=[False])
        except TypeError:
            prepared = self.accelerator.prepare(dl)
        return prepared

    def training_step(
        self,
        model: torch.nn.Module,
        inputs: dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch=None,
    ) -> torch.Tensor:
        """Run one forward/backward pass and return the detached loss.

        Args:
            model: Model to train; switched to train mode here.
            inputs: Batch passed through ``self._prepare_inputs`` and then to
                ``self.compute_loss``.
            num_items_in_batch: Forwarded to ``self.compute_loss``.

        Returns:
            The detached loss, moved to ``self.args.device`` when it lives elsewhere.
        """
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        del inputs
        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            if is_torch_xpu_available():
                torch.xpu.empty_cache()
            elif is_torch_mlu_available():
                torch.mlu.empty_cache()
            elif is_torch_musa_available():
                torch.musa.empty_cache()
            elif is_torch_npu_available():
                torch.npu.empty_cache()
            elif is_torch_mps_available(min_version="2.0"):
                torch.mps.empty_cache()
            elif is_torch_hpu_available():
                logger.warning(
                    "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
                )
            else:
                torch.cuda.empty_cache()

        kwargs = {}
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()

        if self.use_apex:
            # `amp` is not imported at module level, so bind it here; this branch
            # presumably only runs when apex is installed (TODO confirm).
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:  # type: ignore
                scaled_loss.backward()
        else:
            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
                loss = loss / self.args.gradient_accumulation_steps
            if getattr(self.accelerator, "distributed_type", None) and \
                    str(self.accelerator.distributed_type) == "DistributedType.DEEPSPEED":
                kwargs["scale_wrt_gas"] = False
            self.accelerator.backward(loss, **kwargs)

        ret = loss.detach()
        if ret.device != self.args.device:
            ret = ret.to(self.args.device, non_blocking=True)

        # One-shot device report, enabled with KT_DBG_STEP=1.
        if os.environ.get("KT_DBG_STEP", "0") == "1" and not hasattr(self, "_kt_dbg_once"):
            try:
                print(f"[KT-DBG] args.device={self.args.device} loss(before)={loss.device} loss(return)={ret.device}")
            except Exception:
                pass
            self._kt_dbg_once = True
        return ret
class SFTJsonListDataset(TorchDataset):
    """Map-style dataset over a JSON list of ``instruction``/``input``/``output`` records.

    Each item is tokenized on access into ``input_ids``, ``labels`` and
    ``attention_mask`` tensors of equal length (at most ``max_len``). Prompt
    positions in ``labels`` are set to -100; response (and EOS) positions carry
    the token ids.
    """

    def __init__(self, path: str, tokenizer: "AutoTokenizer", max_len: int = 512):
        """Load the records from ``path``.

        Args:
            path: UTF-8 JSON file whose top-level value is a list of records.
            tokenizer: Callable returning a mapping with ``"input_ids"``; must
                expose ``eos_token_id`` (which may be ``None``).
            max_len: Maximum number of tokens per item.

        Raises:
            ValueError: If the JSON top-level value is not a list.
        """
        super().__init__()
        with open(path, "r", encoding="utf-8") as f:
            samples = json.load(f)
        if not isinstance(samples, list):
            raise ValueError(
                f"Expected a JSON list of records in {path!r}, got {type(samples).__name__}."
            )
        self.samples: List[Dict] = samples
        self.tok = tokenizer
        self.max_len = max_len

    @staticmethod
    def build_example(ins: str, inp: str, out: str) -> Dict[str, str]:
        """Strip the three fields and combine instruction and input into a prompt.

        ``None`` is treated as an empty string.
        NOTE(review): instruction and input are concatenated without a separator;
        confirm that this is the intended prompt format.
        """
        ins = (ins or "").strip()
        inp = (inp or "").strip()
        out = (out or "").strip()
        prompt = (ins + inp) if ins else inp
        return {"prompt": prompt, "response": out}

    def __len__(self):
        return len(self.samples)

    def _encode(self, text: str):
        """Tokenize ``text`` without special tokens, truncated to ``max_len`` ids."""
        return self.tok(
            text,
            max_length=self.max_len,
            truncation=True,
            add_special_tokens=False,
        )["input_ids"]

    def __getitem__(self, idx: int):
        """Return the tokenized example at ``idx`` as a dict of ``torch.long`` tensors."""
        rec = self.samples[idx]
        eg = self.build_example(rec.get("instruction", ""), rec.get("input", ""), rec.get("output", ""))

        prompt_ids = self._encode(eg["prompt"])
        response_ids = self._encode(eg["response"])
        eos_id = self.tok.eos_token_id
        completion_ids = response_ids + ([eos_id] if eos_id is not None else [])

        # A prompt that fills the whole window would leave every label at -100,
        # i.e. nothing to learn from. Keep one slot free for the completion.
        if completion_ids and self.max_len > 0:
            prompt_ids = prompt_ids[: self.max_len - 1]

        input_ids = (prompt_ids + completion_ids)[: self.max_len]

        # Mask the prompt; supervise everything after it.
        n_prompt = min(len(prompt_ids), len(input_ids))
        labels = [-100] * n_prompt + input_ids[n_prompt:]
        attention_mask = [1] * len(input_ids)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }
def lora_and_load_adapter(model, tokenizer, sft_data_path, save_adapter_path):
    """Wrap ``model`` with LoRA adapters and fine-tune it on a JSON SFT dataset.

    Args:
        model: Base model handed to ``get_peft_model``.
        tokenizer: Tokenizer used by the dataset, the collator and the trainer.
        sft_data_path: Path of the JSON file read by `SFTJsonListDataset`.
        save_adapter_path: Directory created if missing; used as the trainer's
            ``output_dir`` and as the location of ``model_infra_debug.json``.

    Returns:
        None. Training output is whatever `KTrainer` writes under ``save_adapter_path``.
    """
    # Imported locally so the module's top-level import block stays untouched.
    from transformers import DataCollatorForSeq2Seq

    Path(save_adapter_path).mkdir(parents=True, exist_ok=True)

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=[
            "q_proj",  # FOR DeepSeek-V2-Lite
            "q_a_proj",  # FOR DeepSeek-V3&R1
            "q_b_proj",
            "kv_a_proj_with_mqa",
            "kv_b_proj",
            "o_proj",
            "mlp.gate_proj",
            "mlp.up_proj",
            "mlp.down_proj",
            "shared_experts.gate_proj",
            "shared_experts.up_proj",
            "shared_experts.down_proj",
        ],
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    train_dataset = SFTJsonListDataset(sft_data_path, tokenizer, max_len=512)

    # The dataset already emits `labels` with prompt positions set to -100.
    # `DataCollatorForLanguageModeling(mlm=False)` would rebuild `labels` from
    # `input_ids` and drop that masking, so use a collator that pads and keeps
    # the provided labels instead.
    # NOTE(review): a record whose labels are all -100 has no supervised token;
    # confirm the model's loss tolerates such a batch.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        padding=True,
        label_pad_token_id=-100,
    )

    training_args = TrainingArguments(
        output_dir=save_adapter_path,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        num_train_epochs=1,
        # max_steps=30, # TODO: FOR TEST, will override any value given in num_train_epochs
        learning_rate=1e-4,
        fp16=False,
        logging_steps=10,
        save_steps=200,
        dataloader_drop_last=True,
        ddp_find_unused_parameters=False,
    )

    # Dump the printed module structure next to the adapter for inspection.
    debug_path = os.path.join(save_adapter_path, "model_infra_debug.json")
    with open(debug_path, "w", encoding="utf-8") as f:
        json.dump({"model": str(model)}, f, ensure_ascii=False, indent=2)

    # output = model(input_ids=torch.tensor([[1,2,3]], dtype=torch.int32, device="cuda:0"))
    # loss = output.logits.mean()
    # dot = make_dot(loss, params=dict(model.named_parameters()))
    # dot.render("KT_compute_cpuinfer_moe_model_graph", format="svg")

    trainer = KTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )
    model.config.use_cache = False
    # model.gradient_checkpointing_enable()
    # if hasattr(model, "enable_input_require_grads"):
    #     model.enable_input_require_grads()
    trainer.train()
def inject_lora_layer(model, use_adapter_path):
    """Rebuild a `LoraConfig` from ``adapter_config.json`` and apply it to ``model``.

    Args:
        model: Model handed to ``get_peft_model``.
        use_adapter_path: Directory containing ``adapter_config.json``.

    Returns:
        None. The PEFT-wrapped model is only bound locally.
        NOTE(review): callers presumably rely on ``get_peft_model`` modifying
        ``model`` in place — confirm.

    Raises:
        KeyError: If ``task_type`` in the config is not a `TaskType` member name.
    """
    cfg_path = os.path.join(use_adapter_path, "adapter_config.json")
    with open(cfg_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    task_type_str = (data.get("task_type") or "CAUSAL_LM").upper()

    # Normalize `bias`: null/false mean "none"; `lora_bias: true` upgrades "none" to "lora_only".
    bias = data.get("bias", "none")
    if bias in (None, False):
        bias = "none"
    if data.get("lora_bias") is True and bias == "none":
        bias = "lora_only"

    # Module lists may be stored as a comma-separated string.
    tmods = data.get("target_modules")
    if isinstance(tmods, str):
        tmods = [m.strip() for m in tmods.split(",") if m.strip()]
    mts = data.get("modules_to_save", None)
    if isinstance(mts, str):
        mts = [m.strip() for m in mts.split(",") if m.strip()]

    config_kwargs = {
        "r": data.get("r", 8),
        "lora_alpha": data.get("lora_alpha", 32),
        "lora_dropout": float(data.get("lora_dropout", 0.0)),
        "bias": bias,
        "task_type": TaskType[task_type_str],
        "target_modules": tmods,
        "modules_to_save": mts,
        "init_lora_weights": bool(data.get("init_lora_weights", True)),
        "inference_mode": bool(data.get("inference_mode", True)),
        "use_rslora": bool(data.get("use_rslora", False)),
        "use_dora": bool(data.get("use_dora", False)),
    }

    # Per-module rank/alpha overrides are part of the adapter's shape; forward
    # them when present so the rebuilt config matches the saved one.
    rank_pattern = data.get("rank_pattern") or None
    alpha_pattern = data.get("alpha_pattern") or None
    if rank_pattern is not None:
        config_kwargs["rank_pattern"] = rank_pattern
    if alpha_pattern is not None:
        config_kwargs["alpha_pattern"] = alpha_pattern

    lora_config = LoraConfig(**config_kwargs)
    print(f"lora_config:{lora_config.__dict__}")

    # model = inject_adapter_in_model(lora_config, model)
    model = get_peft_model(model, lora_config)
    model.config.use_cache = False
    model.eval()