mirror of
https://github.com/unslothai/unsloth.git
synced 2026-04-28 03:19:57 +00:00
fix(studio): correct default weight_decay and learning rate (#4695)
* fix(studio): change default weight_decay from 0.01 to 0.001
The default weight decay across Studio was 0.01 but should be 0.001.
Updated the default in all backend fallbacks, the Pydantic model, the
frontend config, and every YAML preset/model-default config.
* fix(studio): auto-set learning rate based on training method
Default LR should be 2e-4 for LoRA/QLoRA and 2e-5 for full fine-tuning.
Frontend: track whether the user has manually edited the LR field via a
_learningRateManuallySet flag (same pattern as trainOnCompletions).
When the training method is switched and the user has not touched the LR,
it is auto-set to the appropriate default. The flag is reset on model load.
Backend: change trainer.py start_training default from 5e-5 to 2e-4,
update default.yaml fallback from 5e-5 to 2e-4, and fix
full_finetune.yaml from 0.0002 (2e-4) to 2e-5.
* refactor(studio): centralize weight_decay and learning rate defaults
Create studio/backend/core/training/constants.py as the single source of
truth for DEFAULT_WEIGHT_DECAY (0.001), DEFAULT_LEARNING_RATE (2e-4),
DEFAULT_LEARNING_RATE_FULL (2e-5), and DEFAULT_LEARNING_RATE_STR ("2e-4").
All backend modules (trainer.py, training.py, worker.py, models/training.py)
now import from constants.py instead of hardcoding values.
On the frontend, add LR_DEFAULT_LORA and LR_DEFAULT_FULL to
config/training.ts and use them in the store instead of magic numbers.
A comment cross-references the backend constants file.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Fix model-specific LR override, persist migration, and flag resets
- Preserve model-specific learning rates from YAML configs when the
async autoSelectTrainingMethod callback fires (fixes Qwen2.5-1.5B
getting 2e-4 instead of its configured 1e-5, etc.)
- Bump zustand persist version to 9 with migration so existing users
with weightDecay=0.01 get updated to 0.001
- Clear _learningRateManuallySet in reset() and applyConfigPatch()
for consistency with trainOnCompletions flag behavior
- Add DEFAULT_LEARNING_RATE_FULL_STR to constants.py
* Refine applyConfigPatch to only clear LR flag when patch includes LR
Only reset _learningRateManuallySet when the applied config patch
actually provides a learningRate value. This prevents unrelated config
patches from silently disarming the manual-edit guard, which would
cause a subsequent setTrainingMethod call to overwrite the user's
custom LR.
* Preserve model-specific LR when switching between qlora and lora
Only auto-switch the learning rate when the training category changes
(adapter <-> full fine-tuning). Switching between qlora and lora keeps
the current LR since both methods share the same learning rate range.
This preserves curated per-model defaults (e.g. 1e-5 for
Qwen2.5-1.5B-Instruct) when the user toggles between adapter methods.
* Remove constants.py, use YAML configs as the source of truth
The YAML config files (model-specific + default.yaml) are the intended
config layer for training defaults. The Python backend fallbacks now use
inline values that match the YAML configs, rather than importing from a
separate constants module. This keeps the config architecture simple:
YAML files are the single source of truth, and the inline Python
fallbacks are just safety nets that mirror them.
* fix(studio): preserve model-specific LR when switching training method
Stash YAML-provided learning rate and use it to restore the correct
value when switching between adapter and full fine-tune modes.
- qlora <-> lora no longer overwrites the model's LR
- full -> adapter restores the YAML LR instead of a hardcoded constant
- selecting a model while on full fine-tune uses LR_DEFAULT_FULL
instead of applying the YAML adapter LR
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
This commit is contained in:
parent
28aaf849bf
commit
e164c930ff
18 changed files with 110 additions and 31 deletions
|
|
@ -10,13 +10,13 @@ training:
|
|||
load_in_4bit: false
|
||||
output_dir: outputs
|
||||
num_epochs: 1
|
||||
learning_rate: 0.0002
|
||||
learning_rate: 2e-5
|
||||
batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
warmup_steps: 5
|
||||
max_steps: 0
|
||||
save_steps: 0
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: false
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ training:
|
|||
warmup_steps: 5
|
||||
max_steps: 0
|
||||
save_steps: 0
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: false
|
||||
|
|
|
|||
|
|
@ -6,13 +6,13 @@ training:
|
|||
max_seq_length: 2048
|
||||
# num_epochs: 4
|
||||
num_epochs: 0
|
||||
learning_rate: 5e-5
|
||||
learning_rate: 2e-4
|
||||
batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
warmup_ratio: 0.1
|
||||
max_steps: 30
|
||||
save_steps: 30
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: true
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ training:
|
|||
warmup_ratio: 0.03
|
||||
max_steps: 30
|
||||
save_steps: 30
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: false
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ training:
|
|||
warmup_ratio: 0.03
|
||||
max_steps: 30
|
||||
save_steps: 30
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: false
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ training:
|
|||
warmup_ratio: 0.03
|
||||
max_steps: 30
|
||||
save_steps: 30
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: false
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ training:
|
|||
warmup_ratio: 0.03
|
||||
max_steps: 30
|
||||
save_steps: 30
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: false
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ training:
|
|||
warmup_ratio: 0.03
|
||||
max_steps: 30
|
||||
save_steps: 30
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: false
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ training:
|
|||
warmup_steps: 5
|
||||
max_steps: 30
|
||||
save_steps: 30
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: true
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ training:
|
|||
warmup_steps: 5
|
||||
max_steps: 30
|
||||
save_steps: 30
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: true
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ training:
|
|||
warmup_steps: 0
|
||||
max_steps: 30
|
||||
save_steps: 30
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: true
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ training:
|
|||
warmup_steps: 5
|
||||
max_steps: 0
|
||||
save_steps: 0
|
||||
weight_decay: 0.01
|
||||
weight_decay: 0.001
|
||||
random_seed: 3407
|
||||
packing: false
|
||||
train_on_completions: false
|
||||
|
|
|
|||
|
|
@ -2658,14 +2658,14 @@ class UnslothTrainer:
|
|||
eval_steps: float = 0.00,
|
||||
output_dir: str | None = None,
|
||||
num_epochs: int = 3,
|
||||
learning_rate: float = 5e-5,
|
||||
learning_rate: float = 2e-4,
|
||||
batch_size: int = 2,
|
||||
gradient_accumulation_steps: int = 4,
|
||||
warmup_steps: int = None,
|
||||
warmup_ratio: float = None,
|
||||
max_steps: int = 0,
|
||||
save_steps: int = 0,
|
||||
weight_decay: float = 0.01,
|
||||
weight_decay: float = 0.001,
|
||||
random_seed: int = 3407,
|
||||
packing: bool = False,
|
||||
train_on_completions: bool = False,
|
||||
|
|
@ -3034,7 +3034,7 @@ class UnslothTrainer:
|
|||
"fp16": not is_bfloat16_supported(),
|
||||
"bf16": is_bfloat16_supported(),
|
||||
"logging_steps": 1,
|
||||
"weight_decay": training_args.get("weight_decay", 0.01),
|
||||
"weight_decay": training_args.get("weight_decay", 0.001),
|
||||
"seed": training_args.get("random_seed", 3407),
|
||||
"output_dir": output_dir,
|
||||
"report_to": _build_report_targets(training_args),
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ class TrainingBackend:
|
|||
"warmup_ratio": kwargs.get("warmup_ratio"),
|
||||
"max_steps": kwargs.get("max_steps", 0),
|
||||
"save_steps": kwargs.get("save_steps", 0),
|
||||
"weight_decay": kwargs.get("weight_decay", 0.01),
|
||||
"weight_decay": kwargs.get("weight_decay", 0.001),
|
||||
"random_seed": kwargs.get("random_seed", 3407),
|
||||
"packing": kwargs.get("packing", False),
|
||||
"optim": kwargs.get("optim", "adamw_8bit"),
|
||||
|
|
|
|||
|
|
@ -795,7 +795,7 @@ def run_training_process(
|
|||
warmup_ratio = config.get("warmup_ratio"),
|
||||
max_steps = max_steps if max_steps and max_steps > 0 else 0,
|
||||
save_steps = save_steps if save_steps and save_steps > 0 else 0,
|
||||
weight_decay = config.get("weight_decay", 0.01),
|
||||
weight_decay = config.get("weight_decay", 0.001),
|
||||
random_seed = config.get("random_seed", 3407),
|
||||
packing = config.get("packing", False),
|
||||
train_on_completions = config.get("train_on_completions", False),
|
||||
|
|
@ -1141,7 +1141,7 @@ def _run_embedding_training(event_queue: Any, stop_queue: Any, config: dict) ->
|
|||
"lr_scheduler_type": config.get("lr_scheduler_type", "linear"),
|
||||
"batch_sampler": BatchSamplers.NO_DUPLICATES,
|
||||
"optim": config.get("optim", "adamw_8bit"),
|
||||
"weight_decay": config.get("weight_decay", 0.01),
|
||||
"weight_decay": config.get("weight_decay", 0.001),
|
||||
"seed": config.get("random_seed", 3407),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ class TrainingStartRequest(BaseModel):
|
|||
warmup_ratio: Optional[float] = Field(None, description = "Warmup ratio")
|
||||
max_steps: Optional[int] = Field(None, description = "Maximum training steps")
|
||||
save_steps: int = Field(100, description = "Steps between checkpoints")
|
||||
weight_decay: float = Field(0.01, description = "Weight decay")
|
||||
weight_decay: float = Field(0.001, description = "Weight decay")
|
||||
random_seed: int = Field(42, description = "Random seed")
|
||||
packing: bool = Field(False, description = "Enable sequence packing")
|
||||
optim: str = Field("adamw_8bit", description = "Optimizer")
|
||||
|
|
|
|||
|
|
@ -90,10 +90,17 @@ export const LR_SCHEDULER_OPTIONS: ReadonlyArray<{ value: string; label: string
|
|||
{ value: "cosine", label: "Cosine" },
|
||||
];
|
||||
|
||||
/**
|
||||
* Method-aware learning rate defaults.
|
||||
* Backend mirrors these in the YAML configs under studio/backend/assets/configs/.
|
||||
*/
|
||||
export const LR_DEFAULT_LORA = 2e-4;
|
||||
export const LR_DEFAULT_FULL = 2e-5;
|
||||
|
||||
export const DEFAULT_HYPERPARAMS = {
|
||||
epochs: 3,
|
||||
contextLength: 2048,
|
||||
learningRate: 2e-4,
|
||||
learningRate: LR_DEFAULT_LORA,
|
||||
optimizerType: "adamw_8bit",
|
||||
lrSchedulerType: "linear",
|
||||
loraRank: 16,
|
||||
|
|
@ -102,7 +109,7 @@ export const DEFAULT_HYPERPARAMS = {
|
|||
loraVariant: "lora" as const,
|
||||
batchSize: 4,
|
||||
gradientAccumulation: 8,
|
||||
weightDecay: 0.01,
|
||||
weightDecay: 0.001,
|
||||
warmupSteps: 5,
|
||||
maxSteps: 60,
|
||||
saveSteps: 0,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
|
||||
|
||||
import { DEFAULT_HYPERPARAMS, STEPS } from "@/config/training";
|
||||
import { DEFAULT_HYPERPARAMS, LR_DEFAULT_FULL, LR_DEFAULT_LORA, STEPS } from "@/config/training";
|
||||
import { authFetch } from "@/features/auth";
|
||||
import { isAdapterMethod } from "@/types/training";
|
||||
import type { ModelType, StepNumber, TrainingMethod } from "@/types/training";
|
||||
import { create } from "zustand";
|
||||
import { persist } from "zustand/middleware";
|
||||
|
|
@ -98,6 +99,15 @@ let _modelConfigController: AbortController | null = null;
|
|||
// since the last auto-set (model load or dataset change).
|
||||
let _trainOnCompletionsManuallySet = false;
|
||||
|
||||
// Track whether the user has manually edited the learning rate
|
||||
// since the last model load. When false, switching training method
|
||||
// auto-sets LR to 2e-4 (LoRA/QLoRA) or 2e-5 (full fine-tune).
|
||||
let _learningRateManuallySet = false;
|
||||
|
||||
// Stash the model-config-provided (YAML) learning rate so that
|
||||
// setTrainingMethod can restore it when switching back from full to adapter.
|
||||
let _yamlLearningRate: number | undefined = undefined;
|
||||
|
||||
const NON_PERSISTED_STATE_KEYS: ReadonlySet<keyof TrainingConfigState> = new Set([
|
||||
"modelType",
|
||||
"isCheckingVision",
|
||||
|
|
@ -165,8 +175,22 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
|
|||
if (get().selectedModel !== modelName) return;
|
||||
|
||||
_trainOnCompletionsManuallySet = false;
|
||||
_learningRateManuallySet = false;
|
||||
_yamlLearningRate = undefined;
|
||||
const patch = mapBackendModelConfigToTrainingPatch(modelDetails.config);
|
||||
|
||||
// If the model config provides a specific learning rate, treat
|
||||
// it as authoritative so the async auto-select does not overwrite it.
|
||||
const modelConfigHasLR = patch.learningRate !== undefined;
|
||||
_yamlLearningRate = patch.learningRate;
|
||||
|
||||
// YAML learning rates are tuned for adapter methods (LoRA/QLoRA).
|
||||
// If the user is currently on full fine-tune, override with the
|
||||
// full-finetune default instead of applying the YAML adapter LR.
|
||||
if (modelConfigHasLR && !isAdapterMethod(get().trainingMethod)) {
|
||||
patch.learningRate = LR_DEFAULT_FULL;
|
||||
}
|
||||
|
||||
// If vision model + image dataset already known, override
|
||||
// trainOnCompletions to false regardless of backend default.
|
||||
if (modelDetails.is_vision && get().isDatasetImage === true) {
|
||||
|
|
@ -174,11 +198,11 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
|
|||
}
|
||||
|
||||
const isAudio = !!modelDetails.is_audio;
|
||||
// Pure audio model → always uncheck trainOnCompletions.
|
||||
// Pure audio model -> always uncheck trainOnCompletions.
|
||||
if (isAudio && !modelDetails.is_vision) {
|
||||
patch.trainOnCompletions = false;
|
||||
}
|
||||
// Audio-capable vision model (e.g. gemma3n) + audio dataset → uncheck.
|
||||
// Audio-capable vision model (e.g. gemma3n) + audio dataset -> uncheck.
|
||||
if (isAudio && modelDetails.is_vision && get().isDatasetAudio) {
|
||||
patch.trainOnCompletions = false;
|
||||
}
|
||||
|
|
@ -197,7 +221,12 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
|
|||
void autoSelectTrainingMethod(modelSizeBytes, patch.contextLength ?? get().contextLength)
|
||||
.then((method) => {
|
||||
if (get().selectedModel !== modelName) return;
|
||||
if (method) set({ trainingMethod: method });
|
||||
if (method) {
|
||||
const lrPatch = !_learningRateManuallySet && !modelConfigHasLR
|
||||
? { learningRate: method === "full" ? LR_DEFAULT_FULL : LR_DEFAULT_LORA }
|
||||
: {};
|
||||
set({ trainingMethod: method, ...lrPatch });
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -366,7 +395,31 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
|
|||
if (state.modelDefaultsAppliedFor === state.selectedModel) return;
|
||||
void loadAndApplyModelDefaults(state.selectedModel);
|
||||
},
|
||||
setTrainingMethod: (trainingMethod) => set({ trainingMethod }),
|
||||
setTrainingMethod: (trainingMethod) => {
|
||||
if (_learningRateManuallySet) {
|
||||
set({ trainingMethod });
|
||||
return;
|
||||
}
|
||||
|
||||
const prev = get().trainingMethod;
|
||||
const wasAdapter = isAdapterMethod(prev);
|
||||
const nowAdapter = isAdapterMethod(trainingMethod);
|
||||
|
||||
// qlora <-> lora: same LR range, don't touch learning rate
|
||||
if (wasAdapter && nowAdapter) {
|
||||
set({ trainingMethod });
|
||||
return;
|
||||
}
|
||||
|
||||
// Category changed (adapter <-> full)
|
||||
if (nowAdapter) {
|
||||
// Switching TO adapter: restore YAML LR if available
|
||||
set({ trainingMethod, learningRate: _yamlLearningRate ?? LR_DEFAULT_LORA });
|
||||
} else {
|
||||
// Switching TO full: no YAML full-LR exists, use constant
|
||||
set({ trainingMethod, learningRate: LR_DEFAULT_FULL });
|
||||
}
|
||||
},
|
||||
setHfToken: (hfToken) =>
|
||||
set({ hfToken: hfToken.trim().replace(/^["']+|["']+$/g, "") }),
|
||||
setDatasetSource: (datasetSource) => set({ datasetSource }),
|
||||
|
|
@ -509,7 +562,10 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
|
|||
}),
|
||||
setEpochs: (epochs) => set({ epochs }),
|
||||
setContextLength: (contextLength) => set({ contextLength }),
|
||||
setLearningRate: (learningRate) => set({ learningRate }),
|
||||
setLearningRate: (learningRate) => {
|
||||
_learningRateManuallySet = true;
|
||||
set({ learningRate });
|
||||
},
|
||||
setOptimizerType: (optimizerType) => set({ optimizerType }),
|
||||
setLrSchedulerType: (lrSchedulerType) => set({ lrSchedulerType }),
|
||||
setLoraRank: (loraRank) => set({ loraRank }),
|
||||
|
|
@ -548,7 +604,12 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
|
|||
set({ finetuneMLPModules }),
|
||||
setTargetModules: (targetModules) => set({ targetModules }),
|
||||
canProceed: () => canProceedForStep(get()),
|
||||
reset: () => set(initialState),
|
||||
reset: () => {
|
||||
_trainOnCompletionsManuallySet = false;
|
||||
_learningRateManuallySet = false;
|
||||
_yamlLearningRate = undefined;
|
||||
set(initialState);
|
||||
},
|
||||
resetToModelDefaults: () => {
|
||||
const { selectedModel } = get();
|
||||
if (!selectedModel) return;
|
||||
|
|
@ -557,13 +618,18 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
|
|||
},
|
||||
applyConfigPatch: (config: BackendModelConfig) => {
|
||||
const patch = mapBackendModelConfigToTrainingPatch(config);
|
||||
// Only clear the manual-edit flag when the config provides a LR,
|
||||
// so unrelated config patches don't silently disarm the guard.
|
||||
if (patch.learningRate !== undefined) {
|
||||
_learningRateManuallySet = false;
|
||||
}
|
||||
set(patch);
|
||||
},
|
||||
};
|
||||
},
|
||||
{
|
||||
name: "unsloth_training_config_v1",
|
||||
version: 8,
|
||||
version: 9,
|
||||
migrate: (persisted, version) => {
|
||||
const s = persisted as Record<string, unknown>;
|
||||
if (version < 2 && s.datasetSubset == null && s.datasetConfig != null) {
|
||||
|
|
@ -593,6 +659,12 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
|
|||
s.datasetLabelMapping ??= {};
|
||||
s.datasetAdvisorNotification ??= null;
|
||||
}
|
||||
if (version < 9) {
|
||||
// weight_decay default changed from 0.01 to 0.001.
|
||||
if (s.weightDecay === 0.01) {
|
||||
s.weightDecay = DEFAULT_HYPERPARAMS.weightDecay;
|
||||
}
|
||||
}
|
||||
return s as unknown as TrainingConfigStore;
|
||||
},
|
||||
partialize: partializePersistedState,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue