mirror of https://github.com/unslothai/unsloth.git (synced 2026-04-28 03:19:57 +00:00)
refactor

parent 4ef25032c1
commit 22f9a65772
3 changed files with 164 additions and 206 deletions

cli.py (228 lines changed)
@@ -1,4 +1,3 @@
-import json
 import logging
 import sys
 import time
@@ -6,7 +5,8 @@ from pathlib import Path
 from typing import Optional, List
 
 import typer
-import yaml
+
+from cli.config import Config, load_config
 
 app = typer.Typer(
     help="Command-line interface for Unsloth training, chat, and export.",
@@ -23,65 +23,6 @@ def configure_logging(verbose: bool):
     )
 
 
-def _load_config(config_path: Optional[Path]) -> dict:
-    if not config_path:
-        return {}
-    path = Path(config_path)
-    if not path.exists():
-        raise typer.BadParameter(f"Config file not found: {config_path}")
-    text = path.read_text(encoding="utf-8")
-    if path.suffix.lower() in {".yaml", ".yml"}:
-        return yaml.safe_load(text) or {}
-    else:
-        return json.loads(text or "{}")
-
-
-def _flatten_config(cfg: dict) -> dict:
-    """
-    Flatten nested config sections into a single dict.
-
-    Expected sections:
-      data: dataset, local_dataset, format_type
-      training: training_type, max_seq_length, load_in_4bit, output_dir, etc.
-      lora: lora_r, lora_alpha, lora_dropout, target_modules, etc.
-      vision: finetune_vision_layers, finetune_language_layers, etc.
-      logging: enable_wandb, wandb_project, wandb_token, enable_tensorboard, etc.
-    """
-    if not isinstance(cfg, dict):
-        return {}
-
-    flattened = {}
-
-    # Handle top-level 'model' key
-    if "model" in cfg:
-        flattened["model"] = cfg["model"]
-
-    sections = ["data", "training", "lora", "vision", "logging"]
-
-    for section in sections:
-        if section in cfg and isinstance(cfg[section], dict):
-            flattened.update(cfg[section])
-
-    return flattened
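For reference, the removed _flatten_config collapsed one level of section nesting into a flat dict. A minimal illustration, with made-up config values:

# Illustrative input/output for the removed helper; not part of the diff.
nested = {
    "model": "unsloth/Llama-3.2-1B-Instruct",  # hypothetical model choice
    "data": {"dataset": "yahma/alpaca-cleaned", "format_type": "auto"},
    "training": {"max_seq_length": 2048, "load_in_4bit": True},
}
flat = _flatten_config(nested)
# flat == {"model": "unsloth/Llama-3.2-1B-Instruct",
#          "dataset": "yahma/alpaca-cleaned", "format_type": "auto",
#          "max_seq_length": 2048, "load_in_4bit": True}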
-
-
-def _merge_config(cfg: dict, defaults: dict, overrides: dict) -> dict:
-    """
-    Merge CLI overrides with config and defaults.
-    CLI override wins, then config value, then default.
-    """
-    merged = {}
-    for key, default in defaults.items():
-        cli_val = overrides.get(key, None)
-        if cli_val is not None:
-            merged[key] = cli_val
-        elif key in cfg and cfg[key] is not None:
-            merged[key] = cfg[key]
-        else:
-            merged[key] = default
-    return merged
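The precedence stated in this docstring (CLI override wins, then config value, then default) plays out like this on illustrative values:

# Illustrative call of the removed helper; not part of the diff.
defaults = {"batch_size": 2, "num_epochs": 3, "learning_rate": 2e-4}
cfg = {"batch_size": 8}            # loaded from the --config file
overrides = {"num_epochs": 5}      # only --num-epochs was passed on the CLI
merged = _merge_config(cfg, defaults, overrides)
# merged == {"batch_size": 8, "num_epochs": 5, "learning_rate": 2e-4}
# batch_size: config beats default; num_epochs: CLI beats both;
# learning_rate: neither set, so the default survives.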
 
 
 @app.command()
 def train(
     model: Optional[str] = typer.Option(
@@ -198,99 +139,22 @@ def train(
     """
     Launch training using the existing Unsloth training backend.
     """
-    cfg = _load_config(config)
-    cfg = _flatten_config(cfg)
+    try:
+        cfg = load_config(config)
+    except FileNotFoundError as e:
+        typer.echo(f"Error: {e}", err=True)
+        raise typer.Exit(code=2)
 
-    # Defaults (match previous behavior)
-    defaults = {
-        "model": None,
-        "training_type": "lora",
-        "max_seq_length": 2048,
-        "load_in_4bit": True,
-        "output_dir": Path("./outputs"),
-        "dataset": None,
-        "local_dataset": None,
-        "format_type": "auto",
-        "num_epochs": 3,
-        "learning_rate": 2e-4,
-        "batch_size": 2,
-        "gradient_accumulation_steps": 4,
-        "warmup_steps": 5,
-        "max_steps": 0,
-        "save_steps": 0,
-        "weight_decay": 0.01,
-        "random_seed": 3407,
-        "packing": False,
-        "train_on_completions": False,
-        "lora_r": 64,
-        "lora_alpha": 16,
-        "lora_dropout": 0.0,
-        "gradient_checkpointing": True,
-        "target_modules": "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
-        "vision_all_linear": False,
-        "finetune_vision_layers": True,
-        "finetune_language_layers": True,
-        "finetune_attention_modules": True,
-        "finetune_mlp_modules": True,
-        "use_rslora": False,
-        "use_loftq": False,
-        "enable_wandb": False,
-        "wandb_project": "unsloth-training",
-        "enable_tensorboard": False,
-        "tensorboard_dir": "runs",
-    }
+    # Apply CLI overrides
+    cli_args = {k: v for k, v in locals().items() if k not in ("config", "verbose", "hf_token", "cfg")}
+    cfg.apply_overrides(**cli_args)
 
-    overrides = {
-        "training_type": training_type,
-        "max_seq_length": max_seq_length,
-        "load_in_4bit": load_in_4bit,
-        "output_dir": output_dir,
-        "dataset": dataset,
-        "local_dataset": local_dataset,
-        "format_type": format_type,
-        "num_epochs": num_epochs,
-        "learning_rate": learning_rate,
-        "batch_size": batch_size,
-        "gradient_accumulation_steps": gradient_accumulation_steps,
-        "warmup_steps": warmup_steps,
-        "max_steps": max_steps,
-        "save_steps": save_steps,
-        "weight_decay": weight_decay,
-        "random_seed": random_seed,
-        "packing": packing,
-        "train_on_completions": train_on_completions,
-        "lora_r": lora_r,
-        "lora_alpha": lora_alpha,
-        "lora_dropout": lora_dropout,
-        "gradient_checkpointing": gradient_checkpointing,
-        "target_modules": target_modules,
-        "vision_all_linear": vision_all_linear,
-        "finetune_vision_layers": finetune_vision_layers,
-        "finetune_language_layers": finetune_language_layers,
-        "finetune_attention_modules": finetune_attention_modules,
-        "finetune_mlp_modules": finetune_mlp_modules,
-        "use_rslora": use_rslora,
-        "use_loftq": use_loftq,
-        "enable_wandb": enable_wandb,
-        "wandb_project": wandb_project,
-        "wandb_token": wandb_token,
-        "enable_tensorboard": enable_tensorboard,
-        "tensorboard_dir": tensorboard_dir,
-    }
-
-    merged = _merge_config(cfg, defaults, overrides)
-
-    model_val = merged.get("model")
-    if not model_val:
+    # Validate required fields
+    if not cfg.model:
         typer.echo("Error: provide --model or set model in --config", err=True)
         raise typer.Exit(code=2)
 
-    # Convert specific types
-    output_dir_val = Path(merged["output_dir"])
-    dataset_val = merged.get("dataset")
-    local_dataset_val = merged.get("local_dataset")
-
-    if not dataset_val and not local_dataset_val:
+    if not cfg.data.dataset and not cfg.data.local_dataset:
         typer.echo(
             "Error: provide --dataset or --local-dataset (or via --config)", err=True
         )
@@ -304,86 +168,38 @@ def train(
     trainer = UnslothTrainer()
 
     model_config = ModelConfig.from_ui_selection(
-        dropdown_value=model_val, search_value=None, hf_token=hf_token, is_lora=False
+        dropdown_value=cfg.model, search_value=None, hf_token=hf_token, is_lora=False
     )
     if not model_config:
         typer.echo("Could not resolve model config", err=True)
         raise typer.Exit(code=1)
 
     is_vision = model_config.is_vision
+    use_lora = cfg.training.training_type.lower() == "lora"
 
     if not trainer.load_model(
         model_name=model_config.identifier,
-        max_seq_length=merged["max_seq_length"],
-        load_in_4bit=merged["load_in_4bit"]
-        if merged["training_type"].lower() == "lora"
-        else False,
+        max_seq_length=cfg.training.max_seq_length,
+        load_in_4bit=cfg.training.load_in_4bit if use_lora else False,
         hf_token=hf_token,
     ):
         typer.echo("Model load failed", err=True)
         raise typer.Exit(code=1)
 
-    use_lora = merged["training_type"].lower() == "lora"
-
-    # Match UI behavior for target modules:
-    # - Text: use parsed target modules list
-    # - Vision: if vision_all_linear, use ["all-linear"]; otherwise empty list
-    target_modules_list = [
-        m.strip() for m in merged["target_modules"].split(",") if m.strip()
-    ]
-    if use_lora and is_vision:
-        if merged["vision_all_linear"]:
-            target_modules_list = ["all-linear"]
-        else:
-            target_modules_list = []
-
-    if not trainer.prepare_model_for_training(
-        use_lora=use_lora,
-        finetune_vision_layers=merged["finetune_vision_layers"],
-        finetune_language_layers=merged["finetune_language_layers"],
-        finetune_attention_modules=merged["finetune_attention_modules"],
-        finetune_mlp_modules=merged["finetune_mlp_modules"],
-        target_modules=target_modules_list,
-        lora_r=merged["lora_r"],
-        lora_alpha=merged["lora_alpha"],
-        lora_dropout=merged["lora_dropout"],
-        use_gradient_checkpointing=merged["gradient_checkpointing"],
-        use_rslora=merged["use_rslora"],
-        use_loftq=merged["use_loftq"],
-    ):
+    if not trainer.prepare_model_for_training(**cfg.model_kwargs(use_lora, is_vision)):
         typer.echo("Model preparation failed", err=True)
         raise typer.Exit(code=1)
 
     ds = trainer.load_and_format_dataset(
-        dataset_source=dataset_val or "",
-        format_type=merged["format_type"],
-        local_datasets=local_dataset_val,
+        dataset_source=cfg.data.dataset or "",
+        format_type=cfg.data.format_type,
+        local_datasets=cfg.data.local_dataset,
     )
     if ds is None:
         typer.echo("Dataset load failed", err=True)
         raise typer.Exit(code=1)
 
-    started = trainer.start_training(
-        dataset=ds,
-        output_dir=str(output_dir_val),
-        num_epochs=merged["num_epochs"],
-        learning_rate=merged["learning_rate"],
-        batch_size=merged["batch_size"],
-        gradient_accumulation_steps=merged["gradient_accumulation_steps"],
-        warmup_steps=merged["warmup_steps"],
-        max_steps=merged["max_steps"],
-        save_steps=merged["save_steps"],
-        weight_decay=merged["weight_decay"],
-        random_seed=merged["random_seed"],
-        packing=merged["packing"],
-        train_on_completions=merged["train_on_completions"],
-        enable_wandb=merged["enable_wandb"],
-        wandb_project=merged["wandb_project"],
-        wandb_token=merged.get("wandb_token"),
-        enable_tensorboard=merged["enable_tensorboard"],
-        tensorboard_dir=merged["tensorboard_dir"],
-        max_seq_length=merged["max_seq_length"],
-    )
+    started = trainer.start_training(dataset=ds, **cfg.training_kwargs())
 
     if not started:
         typer.echo("Training failed to start", err=True)