# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
import math
from functools import partial

from pydantic import BaseModel, ConfigDict
from torch import nn
from torch.optim import AdamW, lr_scheduler

logger = logging.getLogger()


class OptimArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")

    lr: float = 3e-4
    weight_decay: float = 0.1
    epsilon: float = 1e-8
    beta1: float = 0.9
    beta2: float = 0.95
    clip: float = 1.0

    scheduler: str = "cosine"
    warmup: int = 2000
    lr_min_ratio: float = 0.1
    cycle_length: float = 1.0
    cosine_theta: float = 1.0
    annealing_step: int = 1000
    decay_fraction: float = 0.1
    exp_factor: float = 0.5


def lr_linear(step: int, warmup: int, n_steps: int, min_ratio: float) -> float:
    if step < warmup:
        lr = float(step) / warmup
    elif step <= n_steps:
        s = float(step - warmup) / (n_steps - warmup)
        lr = s * min_ratio + (1 - s)
    else:
        lr = min_ratio
    return lr


def lr_inv_sqrt(step: int, warmup: int, exp_factor: float, min_ratio: float) -> float:
    if step < warmup:
        lr = float(step) / warmup
    else:
        lr = max((warmup**exp_factor) / (step**exp_factor), min_ratio)
    return lr


def lr_cosine(
    step: int,
    warmup: int,
    n_steps: int,
    cycle_length: float,
    theta: float,
    min_ratio: float,
) -> float:
    sign = ((step // (n_steps * cycle_length)) % 2) * -2 + 1
    if step < warmup:
        lr = float(step) / warmup
    elif step <= n_steps:
        s = float(step - warmup) / (n_steps - warmup)
        lr = min_ratio + 0.5 * (1 - min_ratio) * (
            sign * math.cos(math.pi * s**theta / cycle_length) + 1
        )
    else:
        lr = min_ratio
    return lr


def lr_wsd(
    step: int,
    warmup: int,
    n_steps: int,
    decay_fraction: float,
    cycle_length: float,
    min_ratio: float,
) -> float:
    """
    Warmup-Stable-Decay (WSD) schedule.

    Reference: "Understanding Warmup-Stable-Decay Learning Rates: A River Valley
    Loss Landscape Perspective", https://arxiv.org/pdf/2410.05192
    """
    cycle_num = step // int(n_steps * cycle_length) + 1
    curr_n_steps = int(n_steps * cycle_length) * cycle_num
    decay_length = int(curr_n_steps * decay_fraction)

    if step < warmup:
        lr = float(step) / warmup
    elif step <= curr_n_steps - decay_length:
        lr = 1.0
    elif step > curr_n_steps - decay_length and step <= curr_n_steps:
        # Linear interpolation gives similar results
        # slope = -(1.0 - min_ratio) / decay_length
        # intercept = min_ratio + ((1.0 - min_ratio) * curr_n_steps) / decay_length
        # lr = slope * step + intercept
        step = step - (curr_n_steps - decay_length)
        lr = 1 / ((step / curr_n_steps) * (1 / min_ratio) + (1 - step / curr_n_steps))
    else:
        lr = min_ratio
    return lr


def build_lr_fn(args: OptimArgs, n_steps: int):
    if args.scheduler == "constant":
        lr_fn = lambda x: 1.0
    elif args.scheduler == "linear":
        lr_fn = partial(
            lr_linear, warmup=args.warmup, n_steps=n_steps, min_ratio=args.lr_min_ratio
        )
    elif args.scheduler == "inv_sqrt":
        lr_fn = partial(
            lr_inv_sqrt,
            warmup=args.warmup,
            exp_factor=args.exp_factor,
            min_ratio=args.lr_min_ratio,
        )
    elif args.scheduler == "cosine":
        lr_fn = partial(
            lr_cosine,
            warmup=args.warmup,
            n_steps=n_steps,
            cycle_length=args.cycle_length,
            theta=args.cosine_theta,
            min_ratio=args.lr_min_ratio,
        )
    elif args.scheduler == "wsd":
        assert args.decay_fraction < args.cycle_length
        lr_fn = partial(
            lr_wsd,
            warmup=args.warmup,
            n_steps=n_steps,
            decay_fraction=args.decay_fraction,
            cycle_length=args.cycle_length,
            min_ratio=args.lr_min_ratio,
        )
    else:
        raise NotImplementedError(f"Unknown scheduler: {args.scheduler}")
    return lr_fn


def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int):
    logger.info("Starting build of optimizer...")
    optimizer = AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay,
        eps=args.epsilon,
        fused=True,  # Faster optim.step but can throw errors
    )

    # scheduler
    lr_fn = build_lr_fn(args, n_steps)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_fn)

    logger.info("Done with build of optimizer.")
    return optimizer, scheduler
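

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of how the pieces fit together: build_lr_fn returns a
# multiplier on args.lr as a function of the step, and build_optimizer wraps it
# in a LambdaLR scheduler around fused AdamW. The toy nn.Linear model and the
# 10-step loop below are hypothetical; fused=True expects parameters on a
# supported device (CUDA on older PyTorch versions), hence the availability guard.
if __name__ == "__main__":
    import torch

    logging.basicConfig(level=logging.INFO)
    args = OptimArgs(scheduler="cosine", warmup=2, lr_min_ratio=0.1)
    n_steps = 10

    # Inspect the raw schedule multipliers (no GPU needed).
    lr_fn = build_lr_fn(args, n_steps)
    logger.info("cosine multipliers: %s", [round(lr_fn(s), 3) for s in range(n_steps)])

    # Exercise the full optimizer/scheduler pair when a CUDA device is present.
    if torch.cuda.is_available():
        model = nn.Linear(8, 8).cuda()  # hypothetical toy model
        optimizer, scheduler = build_optimizer(model, args, n_steps)
        for step in range(n_steps):
            loss = model(torch.randn(4, 8, device="cuda")).pow(2).mean()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            logger.info("step=%d lr=%.2e", step, scheduler.get_last_lr()[0])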