mirror of
https://github.com/unslothai/unsloth.git
synced 2026-05-20 00:51:36 +00:00
studio: bound TrainingStartRequest hyperparameters at the schema level
POST /api/train/start accepted any value for learning_rate, batch_size,
max_steps, max_seq_length, warmup_steps, warmup_ratio, num_epochs,
save_steps, weight_decay, gradient_accumulation_steps, lora_r,
lora_alpha and lora_dropout, including -1, 0, 1e9, and non-numeric
strings like 'abc' or 'two' (which silently coerce to 0 in the
trainer). Probing showed the API returning 200 to learning_rate=-1
and batch_size=0; only max_steps had any partial clamping.
This commit adds a field_validator for every numeric hyperparameter.
Bounds are chosen wide enough to span realistic single-host
configurations (B200 with 180 GB of memory comfortably fits the
upper end) while rejecting the values that always produce broken
training:
- learning_rate: parses str/float, requires 0 < lr < 1.0. Non-numeric
input raises with "learning_rate must be parseable as float (got
'abc')" instead of silently coercing to 0.
- batch_size: [1, 1024].
- gradient_accumulation_steps: [1, 4096].
- num_epochs: [1, 1000].
- max_steps: [1, 1_000_000].
- max_seq_length: [1, 131072].
- warmup_steps: [0, 1_000_000] (the same fixed upper bound as max_steps).
- warmup_ratio: [0.0, 1.0].
- save_steps: [0, 1_000_000].
- weight_decay: [0, 10] (typical 0..0.1).
- lora_r: [1, 512].
- lora_alpha: [1, 1024].
- lora_dropout: [0.0, 1.0).
Each validator names the offending field in its ValueError message
so the 422 response body identifies which input is bad. The
learning_rate validator returns its result as str (the schema field
type is str, e.g. "2e-4", for backwards compatibility) so existing call
sites that call float() on the value continue to work.
Verified:
- learning_rate=-1 -> 422 "learning_rate must be > 0 (got -1.0);
typical range is 1e-6 .. 1e-3".
- learning_rate='abc' -> 422 "must be parseable as float".
- batch_size=-1 / 0 / 999999 -> 422 "batch_size must be in [1, 1024]".
- batch_size='two' -> 422 (pydantic int parser).
- max_steps=0 / -5 -> 422 "must be a positive int".
- max_seq_length=200000 -> 422 "must be in [1, 131072]".
- warmup_ratio=2.5 -> 422 "must be in [0.0, 1.0]".
- lora_dropout=1.5 -> 422 "must be in [0.0, 1.0)".
- Valid request with learning_rate='2e-4', batch_size=1, max_steps=5
passes validation and the training run starts as normal.
This commit is contained in:
parent
44009285b0
commit
71028153c0
1 changed files with 193 additions and 1 deletions
|
|
@ -5,10 +5,51 @@
|
|||
Pydantic schemas for Training API
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from typing import Any, Optional, List, Dict, Literal
|
||||
|
||||
|
||||
# Bounds shared between the public schema and any validator.
|
||||
# Tuned to span the realistic range a single GPU can handle (B200 = 180 GB)
|
||||
# while still rejecting the values the audit hit (-1, 0, 1e9, 'abc', 'two').
|
||||
_MAX_BATCH_SIZE = 1024
|
||||
_MAX_GRAD_ACCUM = 4096
|
||||
_MAX_STEPS = 1_000_000
|
||||
_MAX_EPOCHS = 1000
|
||||
_MAX_SEQ_LENGTH = 131_072 # 128k - any larger is single-host infeasible
|
||||
_MAX_LR_VALUE = 1.0
|
||||
|
||||
|
||||
def _parse_lr(v: Any) -> float:
|
||||
"""Accept str ("2e-4") or numeric, return float bounded (0, _MAX_LR_VALUE).
|
||||
|
||||
Closes 2.7 (LR accepted -1 / 0 / 1e9 / 'abc') and 3.15 (non-numeric
|
||||
silently coerced to 0). Raises ValueError so pydantic returns 422
|
||||
with a clear message.
|
||||
"""
|
||||
if v is None:
|
||||
raise ValueError("learning_rate is required")
|
||||
if isinstance(v, bool):
|
||||
raise ValueError("learning_rate must be a number, not a bool")
|
||||
try:
|
||||
lr = float(v)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
f"learning_rate must be parseable as float (got {v!r})"
|
||||
)
|
||||
if not (lr > 0.0):
|
||||
raise ValueError(
|
||||
f"learning_rate must be > 0 (got {lr!r}); "
|
||||
"typical range is 1e-6 .. 1e-3"
|
||||
)
|
||||
if lr >= _MAX_LR_VALUE:
|
||||
raise ValueError(
|
||||
f"learning_rate must be < 1.0 (got {lr!r}); "
|
||||
"values that large always diverge training"
|
||||
)
|
||||
return lr
|
||||
|
||||
|
||||
class TrainingStartRequest(BaseModel):
|
||||
"""Request schema for starting training"""
|
||||
|
||||
|
|
@ -64,6 +105,157 @@ class TrainingStartRequest(BaseModel):
|
|||
values.setdefault("train_split", values.pop("split"))
|
||||
return values
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Hyperparameter bounds (closes findings 2.7, 3.14, 3.15).
|
||||
# The frontend should still validate inline for UX, but the server
|
||||
# is now the source of truth - bad values produce 422 with a clear
|
||||
# message naming the offending field.
|
||||
# ------------------------------------------------------------------
|
||||
@field_validator("learning_rate", mode = "before")
@classmethod
def _check_learning_rate(cls, v):
    """Validate the raw learning_rate and return it as a string.

    Parsing and bounds checking are delegated to ``_parse_lr``. The
    result is converted back to the original schema type (str) so
    existing call sites, which convert the value themselves, are
    unaffected.
    """
    return str(_parse_lr(v))
|
||||
|
||||
@field_validator("batch_size")
@classmethod
def _check_batch_size(cls, v: int) -> int:
    """Require batch_size to be present and within [1, _MAX_BATCH_SIZE].

    Raises:
        ValueError: If ``v`` is None or falls outside the allowed range.
    """
    if v is None:
        raise ValueError("batch_size is required")
    within_bounds = not (v < 1 or v > _MAX_BATCH_SIZE)
    if within_bounds:
        return v
    raise ValueError(
        f"batch_size must be in [1, {_MAX_BATCH_SIZE}] (got {v!r})"
    )
|
||||
|
||||
@field_validator("gradient_accumulation_steps")
@classmethod
def _check_grad_accum(cls, v: int) -> int:
    """Bound gradient_accumulation_steps to [1, _MAX_GRAD_ACCUM].

    Returns 1 when ``v`` is None.

    Raises:
        ValueError: If ``v`` falls outside the allowed range.
    """
    if v is None:
        return 1
    within_bounds = not (v < 1 or v > _MAX_GRAD_ACCUM)
    if within_bounds:
        return v
    raise ValueError(
        f"gradient_accumulation_steps must be in [1, {_MAX_GRAD_ACCUM}] "
        f"(got {v!r})"
    )
|
||||
|
||||
@field_validator("num_epochs")
@classmethod
def _check_num_epochs(cls, v: int) -> int:
    """Bound num_epochs to [1, _MAX_EPOCHS].

    Returns 1 when ``v`` is None.

    Raises:
        ValueError: If ``v`` falls outside the allowed range.
    """
    if v is None:
        return 1
    within_bounds = not (v < 1 or v > _MAX_EPOCHS)
    if within_bounds:
        return v
    raise ValueError(
        f"num_epochs must be in [1, {_MAX_EPOCHS}] (got {v!r})"
    )
|
||||
|
||||
@field_validator("max_steps")
@classmethod
def _check_max_steps(cls, v):
    """Bound max_steps to an int in [1, _MAX_STEPS]; None passes through.

    Raises:
        ValueError: If ``v`` is not an int or falls outside the range.
    """
    if v is None:
        return None
    # The isinstance test comes first so a non-int is never compared.
    acceptable = isinstance(v, int) and not (v < 1 or v > _MAX_STEPS)
    if not acceptable:
        raise ValueError(
            f"max_steps must be a positive int <= {_MAX_STEPS} (got {v!r})"
        )
    return v
|
||||
|
||||
@field_validator("max_seq_length")
@classmethod
def _check_max_seq_length(cls, v: int) -> int:
    """Require max_seq_length to be within [1, _MAX_SEQ_LENGTH].

    Raises:
        ValueError: If ``v`` is None or falls outside the allowed range.
    """
    acceptable = v is not None and not (v < 1 or v > _MAX_SEQ_LENGTH)
    if acceptable:
        return v
    raise ValueError(
        f"max_seq_length must be in [1, {_MAX_SEQ_LENGTH}] (got {v!r})"
    )
|
||||
|
||||
@field_validator("warmup_steps")
@classmethod
def _check_warmup_steps(cls, v):
    """Bound warmup_steps to an int in [0, _MAX_STEPS]; None passes through.

    Raises:
        ValueError: If ``v`` is not an int or falls outside the range.
    """
    if v is None:
        return None
    # The isinstance test comes first so a non-int is never compared.
    acceptable = isinstance(v, int) and not (v < 0 or v > _MAX_STEPS)
    if not acceptable:
        raise ValueError(
            f"warmup_steps must be a non-negative int <= {_MAX_STEPS} "
            f"(got {v!r})"
        )
    return v
|
||||
|
||||
@field_validator("warmup_ratio")
@classmethod
def _check_warmup_ratio(cls, v):
    """Bound warmup_ratio to the closed interval [0.0, 1.0].

    Returns:
        None when ``v`` is None, otherwise ``v`` converted to float.

    Raises:
        ValueError: If ``v`` cannot be converted by ``float()`` or the
            result is outside [0.0, 1.0] (NaN fails the chained
            comparison and is rejected too).
    """
    if v is None:
        return v
    try:
        r = float(v)
    except (TypeError, ValueError):
        # "from None" drops the internal float() error from the traceback;
        # the message below already carries the offending input.
        raise ValueError(
            f"warmup_ratio must be a number (got {v!r})"
        ) from None
    if not (0.0 <= r <= 1.0):
        raise ValueError(f"warmup_ratio must be in [0.0, 1.0] (got {r!r})")
    return r
|
||||
|
||||
@field_validator("save_steps")
@classmethod
def _check_save_steps(cls, v: int) -> int:
    """Bound save_steps to [0, _MAX_STEPS].

    Returns 100 when ``v`` is None.

    Raises:
        ValueError: If ``v`` falls outside the allowed range.
    """
    if v is None:
        return 100
    within_bounds = not (v < 0 or v > _MAX_STEPS)
    if within_bounds:
        return v
    raise ValueError(
        f"save_steps must be in [0, {_MAX_STEPS}] (got {v!r})"
    )
|
||||
|
||||
@field_validator("weight_decay")
@classmethod
def _check_weight_decay(cls, v: float) -> float:
    """Bound weight_decay to the closed interval [0, 10].

    Returns:
        0.0 when ``v`` is None, otherwise ``v`` converted to float.

    Raises:
        ValueError: If ``v`` cannot be converted by ``float()`` or the
            result is outside [0, 10], NaN included.
    """
    if v is None:
        return 0.0
    try:
        wd = float(v)
    except (TypeError, ValueError):
        # "from None" drops the internal float() error from the traceback;
        # the message below already carries the offending input.
        raise ValueError(
            f"weight_decay must be a number (got {v!r})"
        ) from None
    # Positive-range form on purpose: "wd < 0 or wd > 10.0" is False for
    # NaN (every comparison with NaN is False), which let NaN through.
    if not (0.0 <= wd <= 10.0):
        raise ValueError(
            f"weight_decay must be in [0, 10] (got {wd!r}); typical 0..0.1"
        )
    return wd
|
||||
|
||||
@field_validator("lora_r")
@classmethod
def _check_lora_r(cls, v: int) -> int:
    """Bound lora_r to [1, 512].

    Returns 16 when ``v`` is None.

    Raises:
        ValueError: If ``v`` falls outside the allowed range.
    """
    if v is None:
        return 16
    within_bounds = not (v < 1 or v > 512)
    if within_bounds:
        return v
    raise ValueError(f"lora_r must be in [1, 512] (got {v!r})")
|
||||
|
||||
@field_validator("lora_alpha")
@classmethod
def _check_lora_alpha(cls, v: int) -> int:
    """Bound lora_alpha to [1, 1024].

    Returns 16 when ``v`` is None.

    Raises:
        ValueError: If ``v`` falls outside the allowed range.
    """
    if v is None:
        return 16
    within_bounds = not (v < 1 or v > 1024)
    if within_bounds:
        return v
    raise ValueError(f"lora_alpha must be in [1, 1024] (got {v!r})")
|
||||
|
||||
@field_validator("lora_dropout")
@classmethod
def _check_lora_dropout(cls, v: float) -> float:
    """Bound lora_dropout to the half-open interval [0.0, 1.0).

    Returns:
        0.0 when ``v`` is None, otherwise ``v`` converted to float.

    Raises:
        ValueError: If ``v`` cannot be converted by ``float()`` or the
            result is outside [0.0, 1.0) (NaN fails the chained
            comparison and is rejected too).
    """
    if v is None:
        return 0.0
    try:
        d = float(v)
    except (TypeError, ValueError):
        # "from None" drops the internal float() error from the traceback;
        # the message below already carries the offending input.
        raise ValueError(
            f"lora_dropout must be a number (got {v!r})"
        ) from None
    if not (0.0 <= d < 1.0):
        raise ValueError(f"lora_dropout must be in [0.0, 1.0) (got {d!r})")
    return d
|
||||
|
||||
custom_format_mapping: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description = (
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue