fix(studio): correct default weight_decay and learning rate (#4695)

* fix(studio): change default weight_decay from 0.01 to 0.001

The default weight decay across Studio was 0.01 but should be 0.001.
Updated the default in all backend fallbacks, the Pydantic model, the
frontend config, and every YAML preset/model-default config.

* fix(studio): auto-set learning rate based on training method

Default LR should be 2e-4 for LoRA/QLoRA and 2e-5 for full fine-tuning.

Frontend: track whether the user has manually edited the LR field via a
_learningRateManuallySet flag (same pattern as trainOnCompletions).
When switching training method and the user has not touched the LR,
auto-set it to the appropriate default. Reset the flag on model load.

Backend: change trainer.py start_training default from 5e-5 to 2e-4,
update default.yaml fallback from 5e-5 to 2e-4, and fix
full_finetune.yaml from 0.0002 (2e-4) to 2e-5.

* refactor(studio): centralize weight_decay and learning rate defaults

Create studio/backend/core/training/constants.py as the single source of
truth for DEFAULT_WEIGHT_DECAY (0.001), DEFAULT_LEARNING_RATE (2e-4),
DEFAULT_LEARNING_RATE_FULL (2e-5), and DEFAULT_LEARNING_RATE_STR ("2e-4").

All backend modules (trainer.py, training.py, worker.py, models/training.py)
now import from constants.py instead of hardcoding values.

On the frontend, add LR_DEFAULT_LORA and LR_DEFAULT_FULL to
config/training.ts and use them in the store instead of magic numbers.
A comment cross-references the backend constants file.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix model-specific LR override, persist migration, and flag resets

- Preserve model-specific learning rates from YAML configs when the
  async autoSelectTrainingMethod callback fires (fixes Qwen2.5-1.5B
  getting 2e-4 instead of its configured 1e-5, etc.)
- Bump zustand persist version to 9 with migration so existing users
  with weightDecay=0.01 get updated to 0.001
- Clear _learningRateManuallySet in reset() and applyConfigPatch()
  for consistency with trainOnCompletions flag behavior
- Add DEFAULT_LEARNING_RATE_FULL_STR to constants.py

* Refine applyConfigPatch to only clear LR flag when patch includes LR

Only reset _learningRateManuallySet when the applied config patch
actually provides a learningRate value. This prevents unrelated config
patches from silently disarming the manual-edit guard, which would
cause a subsequent setTrainingMethod call to overwrite the user's
custom LR.

* Preserve model-specific LR when switching between qlora and lora

Only auto-switch the learning rate when the training category changes
(adapter <-> full fine-tuning). Switching between qlora and lora keeps
the current LR since both methods share the same learning rate range.
This preserves curated per-model defaults (e.g. 1e-5 for
Qwen2.5-1.5B-Instruct) when the user toggles between adapter methods.

* Remove constants.py, use YAML configs as the source of truth

The YAML config files (model-specific + default.yaml) are the intended
config layer for training defaults. The Python backend fallbacks now use
inline values that match the YAML configs, rather than importing from a
separate constants module. This keeps the config architecture simple:
YAML files are the single source of truth, and the inline Python
fallbacks are just safety nets that mirror them.

* fix(studio): preserve model-specific LR when switching training method

Stash YAML-provided learning rate and use it to restore the correct
value when switching between adapter and full fine-tune modes.

- qlora <-> lora no longer overwrites the model's LR
- full -> adapter restores the YAML LR instead of a hardcoded constant
- selecting a model while on full fine-tune uses LR_DEFAULT_FULL
  instead of applying the YAML adapter LR

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
This commit is contained in:
Daniel Han 2026-03-31 02:50:25 -07:00 committed by GitHub
parent 28aaf849bf
commit e164c930ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 110 additions and 31 deletions

View file

@@ -10,13 +10,13 @@ training:
load_in_4bit: false
output_dir: outputs
num_epochs: 1
learning_rate: 0.0002
learning_rate: 2e-5
batch_size: 1
gradient_accumulation_steps: 4
warmup_steps: 5
max_steps: 0
save_steps: 0
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@@ -16,7 +16,7 @@ training:
warmup_steps: 5
max_steps: 0
save_steps: 0
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@@ -6,13 +6,13 @@ training:
max_seq_length: 2048
# num_epochs: 4
num_epochs: 0
learning_rate: 5e-5
learning_rate: 2e-4
batch_size: 2
gradient_accumulation_steps: 4
warmup_ratio: 0.1
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true

View file

@@ -12,7 +12,7 @@ training:
warmup_ratio: 0.03
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@@ -11,7 +11,7 @@ training:
warmup_ratio: 0.03
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@@ -11,7 +11,7 @@ training:
warmup_ratio: 0.03
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@@ -11,7 +11,7 @@ training:
warmup_ratio: 0.03
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@@ -11,7 +11,7 @@ training:
warmup_ratio: 0.03
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@@ -13,7 +13,7 @@ training:
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true

View file

@@ -13,7 +13,7 @@ training:
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true

View file

@@ -13,7 +13,7 @@ training:
warmup_steps: 0
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true

View file

@@ -16,7 +16,7 @@ training:
warmup_steps: 5
max_steps: 0
save_steps: 0
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@@ -2658,14 +2658,14 @@ class UnslothTrainer:
eval_steps: float = 0.00,
output_dir: str | None = None,
num_epochs: int = 3,
learning_rate: float = 5e-5,
learning_rate: float = 2e-4,
batch_size: int = 2,
gradient_accumulation_steps: int = 4,
warmup_steps: int = None,
warmup_ratio: float = None,
max_steps: int = 0,
save_steps: int = 0,
weight_decay: float = 0.01,
weight_decay: float = 0.001,
random_seed: int = 3407,
packing: bool = False,
train_on_completions: bool = False,
@@ -3034,7 +3034,7 @@ class UnslothTrainer:
"fp16": not is_bfloat16_supported(),
"bf16": is_bfloat16_supported(),
"logging_steps": 1,
"weight_decay": training_args.get("weight_decay", 0.01),
"weight_decay": training_args.get("weight_decay", 0.001),
"seed": training_args.get("random_seed", 3407),
"output_dir": output_dir,
"report_to": _build_report_targets(training_args),

View file

@@ -160,7 +160,7 @@ class TrainingBackend:
"warmup_ratio": kwargs.get("warmup_ratio"),
"max_steps": kwargs.get("max_steps", 0),
"save_steps": kwargs.get("save_steps", 0),
"weight_decay": kwargs.get("weight_decay", 0.01),
"weight_decay": kwargs.get("weight_decay", 0.001),
"random_seed": kwargs.get("random_seed", 3407),
"packing": kwargs.get("packing", False),
"optim": kwargs.get("optim", "adamw_8bit"),

View file

@@ -795,7 +795,7 @@ def run_training_process(
warmup_ratio = config.get("warmup_ratio"),
max_steps = max_steps if max_steps and max_steps > 0 else 0,
save_steps = save_steps if save_steps and save_steps > 0 else 0,
weight_decay = config.get("weight_decay", 0.01),
weight_decay = config.get("weight_decay", 0.001),
random_seed = config.get("random_seed", 3407),
packing = config.get("packing", False),
train_on_completions = config.get("train_on_completions", False),
@@ -1141,7 +1141,7 @@ def _run_embedding_training(event_queue: Any, stop_queue: Any, config: dict) ->
"lr_scheduler_type": config.get("lr_scheduler_type", "linear"),
"batch_sampler": BatchSamplers.NO_DUPLICATES,
"optim": config.get("optim", "adamw_8bit"),
"weight_decay": config.get("weight_decay", 0.01),
"weight_decay": config.get("weight_decay", 0.001),
"seed": config.get("random_seed", 3407),
}

View file

@@ -81,7 +81,7 @@ class TrainingStartRequest(BaseModel):
warmup_ratio: Optional[float] = Field(None, description = "Warmup ratio")
max_steps: Optional[int] = Field(None, description = "Maximum training steps")
save_steps: int = Field(100, description = "Steps between checkpoints")
weight_decay: float = Field(0.01, description = "Weight decay")
weight_decay: float = Field(0.001, description = "Weight decay")
random_seed: int = Field(42, description = "Random seed")
packing: bool = Field(False, description = "Enable sequence packing")
optim: str = Field("adamw_8bit", description = "Optimizer")

View file

@@ -90,10 +90,17 @@ export const LR_SCHEDULER_OPTIONS: ReadonlyArray<{ value: string; label: string
{ value: "cosine", label: "Cosine" },
];
/**
* Method-aware learning rate defaults.
* Backend mirrors these in the YAML configs under studio/backend/assets/configs/.
*/
export const LR_DEFAULT_LORA = 2e-4;
export const LR_DEFAULT_FULL = 2e-5;
export const DEFAULT_HYPERPARAMS = {
epochs: 3,
contextLength: 2048,
learningRate: 2e-4,
learningRate: LR_DEFAULT_LORA,
optimizerType: "adamw_8bit",
lrSchedulerType: "linear",
loraRank: 16,
@@ -102,7 +109,7 @@ export const DEFAULT_HYPERPARAMS = {
loraVariant: "lora" as const,
batchSize: 4,
gradientAccumulation: 8,
weightDecay: 0.01,
weightDecay: 0.001,
warmupSteps: 5,
maxSteps: 60,
saveSteps: 0,

View file

@@ -1,8 +1,9 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
import { DEFAULT_HYPERPARAMS, STEPS } from "@/config/training";
import { DEFAULT_HYPERPARAMS, LR_DEFAULT_FULL, LR_DEFAULT_LORA, STEPS } from "@/config/training";
import { authFetch } from "@/features/auth";
import { isAdapterMethod } from "@/types/training";
import type { ModelType, StepNumber, TrainingMethod } from "@/types/training";
import { create } from "zustand";
import { persist } from "zustand/middleware";
@@ -98,6 +99,15 @@ let _modelConfigController: AbortController | null = null;
// since the last auto-set (model load or dataset change).
let _trainOnCompletionsManuallySet = false;
// Track whether the user has manually edited the learning rate
// since the last model load. When false, switching training method
// auto-sets LR to 2e-4 (LoRA/QLoRA) or 2e-5 (full fine-tune).
let _learningRateManuallySet = false;
// Stash the model-config-provided (YAML) learning rate so that
// setTrainingMethod can restore it when switching back from full to adapter.
let _yamlLearningRate: number | undefined = undefined;
const NON_PERSISTED_STATE_KEYS: ReadonlySet<keyof TrainingConfigState> = new Set([
"modelType",
"isCheckingVision",
@@ -165,8 +175,22 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
if (get().selectedModel !== modelName) return;
_trainOnCompletionsManuallySet = false;
_learningRateManuallySet = false;
_yamlLearningRate = undefined;
const patch = mapBackendModelConfigToTrainingPatch(modelDetails.config);
// If the model config provides a specific learning rate, treat
// it as authoritative so the async auto-select does not overwrite it.
const modelConfigHasLR = patch.learningRate !== undefined;
_yamlLearningRate = patch.learningRate;
// YAML learning rates are tuned for adapter methods (LoRA/QLoRA).
// If the user is currently on full fine-tune, override with the
// full-finetune default instead of applying the YAML adapter LR.
if (modelConfigHasLR && !isAdapterMethod(get().trainingMethod)) {
patch.learningRate = LR_DEFAULT_FULL;
}
// If vision model + image dataset already known, override
// trainOnCompletions to false regardless of backend default.
if (modelDetails.is_vision && get().isDatasetImage === true) {
@@ -174,11 +198,11 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
}
const isAudio = !!modelDetails.is_audio;
// Pure audio model always uncheck trainOnCompletions.
// Pure audio model -> always uncheck trainOnCompletions.
if (isAudio && !modelDetails.is_vision) {
patch.trainOnCompletions = false;
}
// Audio-capable vision model (e.g. gemma3n) + audio dataset uncheck.
// Audio-capable vision model (e.g. gemma3n) + audio dataset -> uncheck.
if (isAudio && modelDetails.is_vision && get().isDatasetAudio) {
patch.trainOnCompletions = false;
}
@@ -197,7 +221,12 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
void autoSelectTrainingMethod(modelSizeBytes, patch.contextLength ?? get().contextLength)
.then((method) => {
if (get().selectedModel !== modelName) return;
if (method) set({ trainingMethod: method });
if (method) {
const lrPatch = !_learningRateManuallySet && !modelConfigHasLR
? { learningRate: method === "full" ? LR_DEFAULT_FULL : LR_DEFAULT_LORA }
: {};
set({ trainingMethod: method, ...lrPatch });
}
});
}
@@ -366,7 +395,31 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
if (state.modelDefaultsAppliedFor === state.selectedModel) return;
void loadAndApplyModelDefaults(state.selectedModel);
},
setTrainingMethod: (trainingMethod) => set({ trainingMethod }),
setTrainingMethod: (trainingMethod) => {
if (_learningRateManuallySet) {
set({ trainingMethod });
return;
}
const prev = get().trainingMethod;
const wasAdapter = isAdapterMethod(prev);
const nowAdapter = isAdapterMethod(trainingMethod);
// qlora <-> lora: same LR range, don't touch learning rate
if (wasAdapter && nowAdapter) {
set({ trainingMethod });
return;
}
// Category changed (adapter <-> full)
if (nowAdapter) {
// Switching TO adapter: restore YAML LR if available
set({ trainingMethod, learningRate: _yamlLearningRate ?? LR_DEFAULT_LORA });
} else {
// Switching TO full: no YAML full-LR exists, use constant
set({ trainingMethod, learningRate: LR_DEFAULT_FULL });
}
},
setHfToken: (hfToken) =>
set({ hfToken: hfToken.trim().replace(/^["']+|["']+$/g, "") }),
setDatasetSource: (datasetSource) => set({ datasetSource }),
@@ -509,7 +562,10 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
}),
setEpochs: (epochs) => set({ epochs }),
setContextLength: (contextLength) => set({ contextLength }),
setLearningRate: (learningRate) => set({ learningRate }),
setLearningRate: (learningRate) => {
_learningRateManuallySet = true;
set({ learningRate });
},
setOptimizerType: (optimizerType) => set({ optimizerType }),
setLrSchedulerType: (lrSchedulerType) => set({ lrSchedulerType }),
setLoraRank: (loraRank) => set({ loraRank }),
@@ -548,7 +604,12 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
set({ finetuneMLPModules }),
setTargetModules: (targetModules) => set({ targetModules }),
canProceed: () => canProceedForStep(get()),
reset: () => set(initialState),
reset: () => {
_trainOnCompletionsManuallySet = false;
_learningRateManuallySet = false;
_yamlLearningRate = undefined;
set(initialState);
},
resetToModelDefaults: () => {
const { selectedModel } = get();
if (!selectedModel) return;
@@ -557,13 +618,18 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
},
applyConfigPatch: (config: BackendModelConfig) => {
const patch = mapBackendModelConfigToTrainingPatch(config);
// Only clear the manual-edit flag when the config provides a LR,
// so unrelated config patches don't silently disarm the guard.
if (patch.learningRate !== undefined) {
_learningRateManuallySet = false;
}
set(patch);
},
};
},
{
name: "unsloth_training_config_v1",
version: 8,
version: 9,
migrate: (persisted, version) => {
const s = persisted as Record<string, unknown>;
if (version < 2 && s.datasetSubset == null && s.datasetConfig != null) {
@@ -593,6 +659,12 @@ export const useTrainingConfigStore = create<TrainingConfigStore>()(
s.datasetLabelMapping ??= {};
s.datasetAdvisorNotification ??= null;
}
if (version < 9) {
// weight_decay default changed from 0.01 to 0.001.
if (s.weightDecay === 0.01) {
s.weightDecay = DEFAULT_HYPERPARAMS.weightDecay;
}
}
return s as unknown as TrainingConfigStore;
},
partialize: partializePersistedState,