This commit is contained in:
Dan Saunders 2025-12-11 11:52:08 -05:00
parent 4ef25032c1
commit 22f9a65772
3 changed files with 164 additions and 206 deletions

228
cli.py
View file

@@ -1,4 +1,3 @@
import json
import logging
import sys
import time
@@ -6,7 +5,8 @@ from pathlib import Path
from typing import Optional, List
import typer
import yaml
from cli.config import Config, load_config
app = typer.Typer(
help="Command-line interface for Unsloth training, chat, and export.",
@@ -23,65 +23,6 @@ def configure_logging(verbose: bool):
)
def _load_config(config_path: Optional[Path]) -> dict:
if not config_path:
return {}
path = Path(config_path)
if not path.exists():
raise typer.BadParameter(f"Config file not found: {config_path}")
text = path.read_text(encoding="utf-8")
if path.suffix.lower() in {".yaml", ".yml"}:
return yaml.safe_load(text) or {}
else:
return json.loads(text or "{}")
def _flatten_config(cfg: dict) -> dict:
"""
Flatten nested config sections into a single dict.
Expected sections:
data: dataset, local_dataset, format_type
training: training_type, max_seq_length, load_in_4bit, output_dir, etc.
lora: lora_r, lora_alpha, lora_dropout, target_modules, etc.
vision: finetune_vision_layers, finetune_language_layers, etc.
logging: enable_wandb, wandb_project, wandb_token, enable_tensorboard, etc.
"""
if not isinstance(cfg, dict):
return {}
flattened = {}
# Handle top-level 'model' key
if "model" in cfg:
flattened["model"] = cfg["model"]
sections = ["data", "training", "lora", "vision", "logging"]
for section in sections:
if section in cfg and isinstance(cfg[section], dict):
flattened.update(cfg[section])
return flattened
def _merge_config(cfg: dict, defaults: dict, overrides: dict) -> dict:
"""
Merge CLI overrides with config and defaults.
CLI override wins, then config value, then default.
"""
merged = {}
for key, default in defaults.items():
cli_val = overrides.get(key, None)
if cli_val is not None:
merged[key] = cli_val
elif key in cfg and cfg[key] is not None:
merged[key] = cfg[key]
else:
merged[key] = default
return merged
@app.command()
def train(
model: Optional[str] = typer.Option(
@@ -198,99 +139,22 @@ def train(
"""
Launch training using the existing Unsloth training backend.
"""
cfg = _load_config(config)
cfg = _flatten_config(cfg)
try:
cfg = load_config(config)
except FileNotFoundError as e:
typer.echo(f"Error: {e}", err=True)
raise typer.Exit(code=2)
# Defaults (match previous behavior)
defaults = {
"model": None,
"training_type": "lora",
"max_seq_length": 2048,
"load_in_4bit": True,
"output_dir": Path("./outputs"),
"dataset": None,
"local_dataset": None,
"format_type": "auto",
"num_epochs": 3,
"learning_rate": 2e-4,
"batch_size": 2,
"gradient_accumulation_steps": 4,
"warmup_steps": 5,
"max_steps": 0,
"save_steps": 0,
"weight_decay": 0.01,
"random_seed": 3407,
"packing": False,
"train_on_completions": False,
"lora_r": 64,
"lora_alpha": 16,
"lora_dropout": 0.0,
"gradient_checkpointing": True,
"target_modules": "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
"vision_all_linear": False,
"finetune_vision_layers": True,
"finetune_language_layers": True,
"finetune_attention_modules": True,
"finetune_mlp_modules": True,
"use_rslora": False,
"use_loftq": False,
"enable_wandb": False,
"wandb_project": "unsloth-training",
"enable_tensorboard": False,
"tensorboard_dir": "runs",
}
# Apply CLI overrides
cli_args = {k: v for k, v in locals().items() if k not in ("config", "verbose", "hf_token", "cfg")}
cfg.apply_overrides(**cli_args)
overrides = {
"training_type": training_type,
"max_seq_length": max_seq_length,
"load_in_4bit": load_in_4bit,
"output_dir": output_dir,
"dataset": dataset,
"local_dataset": local_dataset,
"format_type": format_type,
"num_epochs": num_epochs,
"learning_rate": learning_rate,
"batch_size": batch_size,
"gradient_accumulation_steps": gradient_accumulation_steps,
"warmup_steps": warmup_steps,
"max_steps": max_steps,
"save_steps": save_steps,
"weight_decay": weight_decay,
"random_seed": random_seed,
"packing": packing,
"train_on_completions": train_on_completions,
"lora_r": lora_r,
"lora_alpha": lora_alpha,
"lora_dropout": lora_dropout,
"gradient_checkpointing": gradient_checkpointing,
"target_modules": target_modules,
"vision_all_linear": vision_all_linear,
"finetune_vision_layers": finetune_vision_layers,
"finetune_language_layers": finetune_language_layers,
"finetune_attention_modules": finetune_attention_modules,
"finetune_mlp_modules": finetune_mlp_modules,
"use_rslora": use_rslora,
"use_loftq": use_loftq,
"enable_wandb": enable_wandb,
"wandb_project": wandb_project,
"wandb_token": wandb_token,
"enable_tensorboard": enable_tensorboard,
"tensorboard_dir": tensorboard_dir,
}
merged = _merge_config(cfg, defaults, overrides)
model_val = merged.get("model")
if not model_val:
# Validate required fields
if not cfg.model:
typer.echo("Error: provide --model or set model in --config", err=True)
raise typer.Exit(code=2)
# Convert specific types
output_dir_val = Path(merged["output_dir"])
dataset_val = merged.get("dataset")
local_dataset_val = merged.get("local_dataset")
if not dataset_val and not local_dataset_val:
if not cfg.data.dataset and not cfg.data.local_dataset:
typer.echo(
"Error: provide --dataset or --local-dataset (or via --config)", err=True
)
@@ -304,86 +168,38 @@ def train(
trainer = UnslothTrainer()
model_config = ModelConfig.from_ui_selection(
dropdown_value=model_val, search_value=None, hf_token=hf_token, is_lora=False
dropdown_value=cfg.model, search_value=None, hf_token=hf_token, is_lora=False
)
if not model_config:
typer.echo("Could not resolve model config", err=True)
raise typer.Exit(code=1)
is_vision = model_config.is_vision
use_lora = cfg.training.training_type.lower() == "lora"
if not trainer.load_model(
model_name=model_config.identifier,
max_seq_length=merged["max_seq_length"],
load_in_4bit=merged["load_in_4bit"]
if merged["training_type"].lower() == "lora"
else False,
max_seq_length=cfg.training.max_seq_length,
load_in_4bit=cfg.training.load_in_4bit if use_lora else False,
hf_token=hf_token,
):
typer.echo("Model load failed", err=True)
raise typer.Exit(code=1)
use_lora = merged["training_type"].lower() == "lora"
# Match UI behavior for target modules:
# - Text: use parsed target modules list
# - Vision: if vision_all_linear, use ["all-linear"]; otherwise empty list
target_modules_list = [
m.strip() for m in merged["target_modules"].split(",") if m.strip()
]
if use_lora and is_vision:
if merged["vision_all_linear"]:
target_modules_list = ["all-linear"]
else:
target_modules_list = []
if not trainer.prepare_model_for_training(
use_lora=use_lora,
finetune_vision_layers=merged["finetune_vision_layers"],
finetune_language_layers=merged["finetune_language_layers"],
finetune_attention_modules=merged["finetune_attention_modules"],
finetune_mlp_modules=merged["finetune_mlp_modules"],
target_modules=target_modules_list,
lora_r=merged["lora_r"],
lora_alpha=merged["lora_alpha"],
lora_dropout=merged["lora_dropout"],
use_gradient_checkpointing=merged["gradient_checkpointing"],
use_rslora=merged["use_rslora"],
use_loftq=merged["use_loftq"],
):
if not trainer.prepare_model_for_training(**cfg.model_kwargs(use_lora, is_vision)):
typer.echo("Model preparation failed", err=True)
raise typer.Exit(code=1)
ds = trainer.load_and_format_dataset(
dataset_source=dataset_val or "",
format_type=merged["format_type"],
local_datasets=local_dataset_val,
dataset_source=cfg.data.dataset or "",
format_type=cfg.data.format_type,
local_datasets=cfg.data.local_dataset,
)
if ds is None:
typer.echo("Dataset load failed", err=True)
raise typer.Exit(code=1)
started = trainer.start_training(
dataset=ds,
output_dir=str(output_dir_val),
num_epochs=merged["num_epochs"],
learning_rate=merged["learning_rate"],
batch_size=merged["batch_size"],
gradient_accumulation_steps=merged["gradient_accumulation_steps"],
warmup_steps=merged["warmup_steps"],
max_steps=merged["max_steps"],
save_steps=merged["save_steps"],
weight_decay=merged["weight_decay"],
random_seed=merged["random_seed"],
packing=merged["packing"],
train_on_completions=merged["train_on_completions"],
enable_wandb=merged["enable_wandb"],
wandb_project=merged["wandb_project"],
wandb_token=merged.get("wandb_token"),
enable_tensorboard=merged["enable_tensorboard"],
tensorboard_dir=merged["tensorboard_dir"],
max_seq_length=merged["max_seq_length"],
)
started = trainer.start_training(dataset=ds, **cfg.training_kwargs())
if not started:
typer.echo("Training failed to start", err=True)