studio: bound TrainingStartRequest hyperparameters at the schema level

POST /api/train/start accepted any value for learning_rate, batch_size,
max_steps, max_seq_length, warmup_steps, warmup_ratio, num_epochs,
save_steps, weight_decay, gradient_accumulation_steps, lora_r,
lora_alpha and lora_dropout, including -1, 0, 1e9, and non-numeric
strings like 'abc' or 'two' (which silently coerce to 0 in the
trainer). Probing showed the API returning 200 to learning_rate=-1
and batch_size=0; only max_steps had any partial clamping.

This commit adds field_validator on every numeric hyperparameter.
Bounds are chosen wide enough to span realistic single-host
configurations (B200 with 180 GB of memory comfortably fits the
upper end) while rejecting the values that always produce broken
training:

- learning_rate: parses str/float, requires 0 < lr < 1.0. Non-numeric
  input raises with "learning_rate must be parseable as float (got
  'abc')" instead of silently coercing to 0.
- batch_size: [1, 1024].
- gradient_accumulation_steps: [1, 4096].
- num_epochs: [1, 1000].
- max_steps: [1, 1_000_000].
- max_seq_length: [1, 131072].
- warmup_steps: [0, 1_000_000] (the same fixed cap as max_steps; it is
  not compared against the request's own max_steps value).
- warmup_ratio: [0.0, 1.0].
- save_steps: [0, 1_000_000].
- weight_decay: [0, 10] (typical 0..0.1).
- lora_r: [1, 512].
- lora_alpha: [1, 1024].
- lora_dropout: [0.0, 1.0).

Each validator names the offending field in its ValueError message
so the 422 response body identifies which input is bad. The
learning_rate validator returns its result as a str (the schema field
is typed str, e.g. "2e-4", for backwards compatibility) so existing
call sites that float() the value continue to work.

Verified:
- learning_rate=-1 -> 422 "learning_rate must be > 0 (got -1.0);
  typical range is 1e-6 .. 1e-3".
- learning_rate='abc' -> 422 "must be parseable as float".
- batch_size=-1 / 0 / 999999 -> 422 "batch_size must be in [1, 1024]".
- batch_size='two' -> 422 (pydantic int parser).
- max_steps=0 / -5 -> 422 "must be a positive int".
- max_seq_length=200000 -> 422 "must be in [1, 131072]".
- warmup_ratio=2.5 -> 422 "must be in [0.0, 1.0]".
- lora_dropout=1.5 -> 422 "must be in [0.0, 1.0)".
- Valid request with learning_rate='2e-4', batch_size=1, max_steps=5
  passes validation and the training run starts as normal.
This commit is contained in:
Daniel Han 2026-05-11 12:46:18 +00:00
parent 44009285b0
commit 71028153c0

View file

@ -5,10 +5,51 @@
Pydantic schemas for Training API
"""
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing import Any, Optional, List, Dict, Literal
# Bounds shared between the public schema and any validator.
# Tuned to span the realistic range a single GPU can handle (B200 = 180 GB)
# while still rejecting the values the audit hit (-1, 0, 1e9, 'abc', 'two').
_MAX_BATCH_SIZE = 1024
_MAX_GRAD_ACCUM = 4096
_MAX_STEPS = 1_000_000
_MAX_EPOCHS = 1000
_MAX_SEQ_LENGTH = 131_072 # 128k - any larger is single-host infeasible
_MAX_LR_VALUE = 1.0
def _parse_lr(v: Any) -> float:
"""Accept str ("2e-4") or numeric, return float bounded (0, _MAX_LR_VALUE).
Closes 2.7 (LR accepted -1 / 0 / 1e9 / 'abc') and 3.15 (non-numeric
silently coerced to 0). Raises ValueError so pydantic returns 422
with a clear message.
"""
if v is None:
raise ValueError("learning_rate is required")
if isinstance(v, bool):
raise ValueError("learning_rate must be a number, not a bool")
try:
lr = float(v)
except (TypeError, ValueError):
raise ValueError(
f"learning_rate must be parseable as float (got {v!r})"
)
if not (lr > 0.0):
raise ValueError(
f"learning_rate must be > 0 (got {lr!r}); "
"typical range is 1e-6 .. 1e-3"
)
if lr >= _MAX_LR_VALUE:
raise ValueError(
f"learning_rate must be < 1.0 (got {lr!r}); "
"values that large always diverge training"
)
return lr
class TrainingStartRequest(BaseModel):
"""Request schema for starting training"""
@ -64,6 +105,157 @@ class TrainingStartRequest(BaseModel):
values.setdefault("train_split", values.pop("split"))
return values
# ------------------------------------------------------------------
# Hyperparameter bounds (closes findings 2.7, 3.14, 3.15).
# The frontend should still validate inline for UX, but the server
# is now the source of truth - bad values produce 422 with a clear
# message naming the offending field.
# ------------------------------------------------------------------
@field_validator("learning_rate", mode = "before")
@classmethod
def _check_learning_rate(cls, v):
    """Validate the raw learning_rate and hand it back as a string.

    Registered in "before" mode, so ``v`` is the raw input. ``_parse_lr``
    does the parsing and bounds check and raises ValueError on bad input.
    The accepted float is converted back with ``str()`` because the schema
    field is a str; existing call sites convert the value themselves and
    are unaffected. The returned text is ``str()`` of the parsed float, so
    it need not match the spelling of the input.
    """
    return str(_parse_lr(v))
@field_validator("batch_size")
@classmethod
def _check_batch_size(cls, v: int) -> int:
    """Require batch_size to be present and within [1, _MAX_BATCH_SIZE].

    Raises:
        ValueError: if the value is None or outside the allowed range.
    """
    if v is not None:
        if v < 1 or v > _MAX_BATCH_SIZE:
            raise ValueError(
                f"batch_size must be in [1, {_MAX_BATCH_SIZE}] (got {v!r})"
            )
        return v
    raise ValueError("batch_size is required")
@field_validator("gradient_accumulation_steps")
@classmethod
def _check_grad_accum(cls, v: int) -> int:
    """Bound gradient_accumulation_steps to [1, _MAX_GRAD_ACCUM].

    A None value is replaced by 1.

    Raises:
        ValueError: if the value is outside the allowed range.
    """
    if v is None:
        return 1
    out_of_range = v < 1 or v > _MAX_GRAD_ACCUM
    if not out_of_range:
        return v
    raise ValueError(
        f"gradient_accumulation_steps must be in [1, {_MAX_GRAD_ACCUM}] "
        f"(got {v!r})"
    )
@field_validator("num_epochs")
@classmethod
def _check_num_epochs(cls, v: int) -> int:
    """Bound num_epochs to [1, _MAX_EPOCHS]; None becomes 1.

    Raises:
        ValueError: if the value is outside the allowed range.
    """
    if v is None:
        return 1
    out_of_range = v < 1 or v > _MAX_EPOCHS
    if out_of_range:
        raise ValueError(
            f"num_epochs must be in [1, {_MAX_EPOCHS}] (got {v!r})"
        )
    return v
@field_validator("max_steps")
@classmethod
def _check_max_steps(cls, v):
    """Accept None, or an int in [1, _MAX_STEPS], for max_steps.

    Raises:
        ValueError: if the value is not an int or is outside the range.
    """
    if v is None:
        return None
    # The isinstance test runs first so non-int values are never compared.
    acceptable = isinstance(v, int) and not (v < 1 or v > _MAX_STEPS)
    if acceptable:
        return v
    raise ValueError(
        f"max_steps must be a positive int <= {_MAX_STEPS} (got {v!r})"
    )
@field_validator("max_seq_length")
@classmethod
def _check_max_seq_length(cls, v: int) -> int:
    """Require max_seq_length to be set and within [1, _MAX_SEQ_LENGTH].

    Raises:
        ValueError: if the value is None or outside the allowed range.
    """
    # The None test short-circuits before any numeric comparison.
    usable = v is not None and not (v < 1 or v > _MAX_SEQ_LENGTH)
    if usable:
        return v
    raise ValueError(
        f"max_seq_length must be in [1, {_MAX_SEQ_LENGTH}] (got {v!r})"
    )
@field_validator("warmup_steps")
@classmethod
def _check_warmup_steps(cls, v):
    """Accept None, or an int in [0, _MAX_STEPS], for warmup_steps.

    The upper limit is the fixed _MAX_STEPS constant; the value is not
    compared against the request's own max_steps.

    Raises:
        ValueError: if the value is not an int or is outside the range.
    """
    if v is None:
        return None
    # The isinstance test runs first so non-int values are never compared.
    acceptable = isinstance(v, int) and not (v < 0 or v > _MAX_STEPS)
    if not acceptable:
        raise ValueError(
            f"warmup_steps must be a non-negative int <= {_MAX_STEPS} "
            f"(got {v!r})"
        )
    return v
@field_validator("warmup_ratio")
@classmethod
def _check_warmup_ratio(cls, v):
    """Accept None, or a number in [0.0, 1.0], for warmup_ratio.

    Returns:
        None when the input is None, otherwise the value as a float.

    Raises:
        ValueError: if the value is not numeric or is outside the range.
    """
    if v is None:
        return None
    try:
        ratio = float(v)
    except (TypeError, ValueError):
        raise ValueError(f"warmup_ratio must be a number (got {v!r})")
    # The chained comparison is False for NaN, so NaN ends up rejected.
    if 0.0 <= ratio <= 1.0:
        return ratio
    raise ValueError(f"warmup_ratio must be in [0.0, 1.0] (got {ratio!r})")
@field_validator("save_steps")
@classmethod
def _check_save_steps(cls, v: int) -> int:
    """Bound save_steps to [0, _MAX_STEPS]; None becomes 100.

    Raises:
        ValueError: if the value is outside the allowed range.
    """
    if v is None:
        return 100
    out_of_range = v < 0 or v > _MAX_STEPS
    if not out_of_range:
        return v
    raise ValueError(
        f"save_steps must be in [0, {_MAX_STEPS}] (got {v!r})"
    )
@field_validator("weight_decay")
@classmethod
def _check_weight_decay(cls, v: float) -> float:
    """Bound weight_decay to [0, 10]; None becomes 0.0.

    Returns:
        The value converted to float (0.0 when the input is None).

    Raises:
        ValueError: if the value is not numeric, is NaN, or is outside
            the allowed range.
    """
    if v is None:
        return 0.0
    try:
        wd = float(v)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"weight_decay must be a number (got {v!r})"
        ) from err
    # Negated chained comparison on purpose: every comparison with NaN is
    # False, so a pair of `wd < 0 or wd > 10.0` tests would let NaN pass.
    # This form rejects it, matching warmup_ratio and lora_dropout.
    if not (0.0 <= wd <= 10.0):
        raise ValueError(
            f"weight_decay must be in [0, 10] (got {wd!r}); typical 0..0.1"
        )
    return wd
@field_validator("lora_r")
@classmethod
def _check_lora_r(cls, v: int) -> int:
    """Bound lora_r to [1, 512]; None becomes 16.

    Raises:
        ValueError: if the value is outside the allowed range.
    """
    if v is None:
        return 16
    out_of_range = v < 1 or v > 512
    if not out_of_range:
        return v
    raise ValueError(f"lora_r must be in [1, 512] (got {v!r})")
@field_validator("lora_alpha")
@classmethod
def _check_lora_alpha(cls, v: int) -> int:
    """Bound lora_alpha to [1, 1024]; None becomes 16.

    Raises:
        ValueError: if the value is outside the allowed range.
    """
    if v is None:
        return 16
    out_of_range = v < 1 or v > 1024
    if out_of_range:
        raise ValueError(f"lora_alpha must be in [1, 1024] (got {v!r})")
    return v
@field_validator("lora_dropout")
@classmethod
def _check_lora_dropout(cls, v: float) -> float:
    """Bound lora_dropout to the half-open range [0.0, 1.0); None becomes 0.0.

    Returns:
        The value converted to float (0.0 when the input is None).

    Raises:
        ValueError: if the value is not numeric or is outside the range.
    """
    if v is None:
        return 0.0
    try:
        dropout = float(v)
    except (TypeError, ValueError):
        raise ValueError(f"lora_dropout must be a number (got {v!r})")
    # The chained comparison is False for NaN, so NaN ends up rejected.
    if 0.0 <= dropout < 1.0:
        return dropout
    raise ValueError(f"lora_dropout must be in [0.0, 1.0) (got {dropout!r})")
custom_format_mapping: Optional[Dict[str, Any]] = Field(
None,
description = (