From 8c61ab5e67ab044cd04176e03d45ef6845b62302 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 13 Feb 2025 11:58:23 -0800 Subject: [PATCH 1/4] Fix multiprocessing dataloader checkpointing and use it in the train script (#50) --- bytelatent/args.py | 2 - .../data/iterators/abstract_iterator.py | 10 ++++ bytelatent/data/iterators/arrow_iterator.py | 15 +++-- .../data/iterators/multiprocess_iterator.py | 27 ++++++--- bytelatent/train.py | 56 ++++++++++++------- 5 files changed, 77 insertions(+), 33 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index 263e8e3..47bd0f9 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,10 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -import json import logging import os from typing import Any -import fsspec import numpy as np import yaml from omegaconf import OmegaConf diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py index 7fb442b..8ac7f19 100644 --- a/bytelatent/data/iterators/abstract_iterator.py +++ b/bytelatent/data/iterators/abstract_iterator.py @@ -21,3 +21,13 @@ class IteratorState(Generic[C]): @abc.abstractmethod def build(self) -> StatefulIterator[T, C]: pass + + +def get_state_and_refresh(iterator: StatefulIterator): + # Re-init dataloader and iterator is necessary since get_state() + # on mp iterator shuts down MP to correctly persist state and it needs + # to be restarted. + state = iterator.get_state() + data_loader = state.build() + py_iterator = data_loader.create_iter() + return state, data_loader, py_iterator diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 1c68d3a..995cd02 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -60,6 +60,13 @@ def shard_sort_key(file: str): return shard_number +def maybe_truncate_string(text: str, max_length: int): + if len(text) <= max_length: + return text + else: + return text[:max_length] + "..." 
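[Note] The `get_state_and_refresh` helper added above exists because calling `get_state()` on the multiprocess iterator shuts down its worker process in order to persist state, so the data loader and its Python iterator must be rebuilt before training continues. The snippet below is a minimal standalone sketch of that save-and-rebuild round trip; `CountingIterator` and `CountingState` are toy stand-ins that only mirror the interface shown in the diff (`get_state()` -> state with `build()`, iterator with `create_iter()`), not bytelatent code.

```python
# Toy sketch of the StatefulIterator contract used by get_state_and_refresh().
from dataclasses import dataclass


@dataclass
class CountingState:
    position: int

    def build(self) -> "CountingIterator":
        # Rebuild the iterator from the persisted state, as get_state_and_refresh() does.
        return CountingIterator(start=self.position)


class CountingIterator:
    def __init__(self, start: int = 0):
        self.position = start

    def create_iter(self):
        while True:
            value = self.position
            self.position += 1
            yield value

    def get_state(self) -> CountingState:
        # A real multiprocess iterator would also shut down its worker here,
        # which is why the caller must rebuild the loader afterwards.
        return CountingState(position=self.position)


def get_state_and_refresh(iterator):
    # Same shape as the helper added in abstract_iterator.py.
    state = iterator.get_state()
    data_loader = state.build()
    py_iterator = data_loader.create_iter()
    return state, data_loader, py_iterator


if __name__ == "__main__":
    loader = CountingIterator()
    it = loader.create_iter()
    print(next(it), next(it))        # 0 1
    state, loader, it = get_state_and_refresh(loader)
    print(state.position, next(it))  # 2 2 -- iteration resumes after the last consumed batch
```

This mirrors how train.py refreshes `train_state.data_loader_state`, `data_loader`, and `batch_iterator` right before each `checkpoint.save` call in the diff below.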
+ + class ArrowFileIterator(StatefulIterator): def __init__( self, @@ -235,9 +242,8 @@ class ArrowFileIterator(StatefulIterator): yield out def _set_row_num(self, target_row_num: int): - logger.info( - f"Setting arrow position to {target_row_num} for {self.dataset_files}" - ) + data_str = maybe_truncate_string(str(self.dataset_files), 200) + logger.info(f"Setting arrow position to {target_row_num} for {data_str}") if target_row_num is None or target_row_num == 0: self.row_num = 0 self.dataset = None @@ -285,6 +291,7 @@ class ArrowFileIterator(StatefulIterator): else: curr_remaining -= len(batch) self.row_num = target_row_num + data_str = maybe_truncate_string(str(self.dataset_files), 200) logger.info( - f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" + f"Finished setting arrow position to {target_row_num} for {data_str}" ) diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index 49d99ac..33bde94 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -54,9 +54,10 @@ def start_work_from_state( if stop_event.is_set(): # Signal the end of output, this ensures that even if the queue takes a while to # buffer, that the main thread receives everything (and tosses this fake batch) - logging.info( + logging.debug( "Worker thread: Stop event detected, outputting is_final=True batch" ) + logging.debug("Worker thread: batch_queue full=%s", batch_queue.full()) batch_queue.put( Batch( x=np.zeros((1, 1)), @@ -67,14 +68,17 @@ def start_work_from_state( ngram_ids=None, ) ) + logging.debug( + "Worker thread: is_final=True batch put in queue, breaking from loop." + ) break try: - logging.info("Worker thread: outputting state") - state_queue.put(iterator.get_state(), timeout=1) - logging.info("Worker thread: state dump complete") + logging.debug("Worker thread: outputting state") + state_queue.put(stateful_iterator.get_state(), timeout=1) + logging.debug("Worker thread: state dump complete") state_dumped_event.set() - logging.info("Worker thread: set state_dump_event") + logging.debug("Worker thread: set state_dump_event") except Full: raise ValueError( "Attempted to dump state into the state queue, but it was full" @@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator): serialized_prefetch_buffer=serialized_prefetch_buffer, ) else: - logging.info("Main thread: Sending stop iteration event") + logging.debug("Main thread: Sending stop iteration event") self.stop_iterating_event.set() - logging.info("Main thread: Waiting for state_dumped event") - self.state_dumped_event.wait() + logging.debug( + "Main thread: Emptying the batch_queue until batch.is_final=True is found." 
+ ) self.prefetch_buffer = [] final_batch_received = False while True: try: batch = self.batch_queue.get(timeout=1) if batch.is_final: + logging.debug( + "Main thread: is_final=True batch found, stopping fetch from batch_queue" + ) final_batch_received = True break self.prefetch_buffer.append(batch) @@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator): logging.warning("Main thread: batch_queue is abnormally empty") assert final_batch_received + logging.debug("Main thread: Waiting for state_dumped event") + self.state_dumped_event.wait() + try: base_iterator_state = self.state_queue.get(timeout=1) assert isinstance(base_iterator_state, IteratorState) diff --git a/bytelatent/train.py b/bytelatent/train.py index 0ee87df..3669167 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -26,6 +26,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -35,7 +36,6 @@ from bytelatent.distributed import ( check_model_value_range, clean_env, dist_mean, - dist_mean_dict, dist_sum, get_device_mesh, get_is_master, @@ -88,6 +88,13 @@ def get_iterator_state_name(iterator_state): raise ValueError(f"Unsupported iterator to get name from: {iterator_state}") +def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float: + if isinstance(num, (torch.Tensor, np.ndarray)): + return num.item() + else: + return num + + # TODO: Make this pydantic based instead of data class based # TODO: Generalize this to any iterator state @dataclass @@ -603,20 +610,20 @@ def train(args: TrainArgs): # step: Metric at a step # interval: Metric averaged/summed across all steps since the last log interval. 
# Typically, this is 10 - step_loss_per_gpu = loss.item() - step_loss_across_gpus = dist_mean(step_loss_per_gpu).item() - interval_loss_per_gpu = np.mean(step_losses).item() - interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item() + step_loss_per_gpu = loss + step_loss_across_gpus = dist_mean(step_loss_per_gpu) + interval_loss_per_gpu = np.mean(step_losses) + interval_loss_across_gpus = dist_mean(interval_loss_per_gpu) stacked_tok_loss = torch.cat(step_tok_losses, dim=0) - interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item() + interval_total_tok_loss_per_gpu = stacked_tok_loss.sum() interval_total_tok_loss_across_gpus = dist_sum( interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16 - ).item() - interval_total_n_bytes_per_gpu = n_bytes.item() + ) + interval_total_n_bytes_per_gpu = n_bytes interval_total_n_bytes_across_gpus = dist_sum( n_bytes, reduce_dtype=torch.bfloat16 - ).item() + ) interval_bpb_per_gpu = ( interval_total_tok_loss_per_gpu @@ -645,18 +652,20 @@ def train(args: TrainArgs): }, "memory": gpu_mem_stats._asdict(), "loss": { - "step_per_gpu": step_loss_per_gpu, - "step_across_gpu": step_loss_across_gpus, - "interval_per_gpu": interval_loss_per_gpu, - "interval_across_gpu": interval_loss_across_gpus, + "step_per_gpu": to_py_num(step_loss_per_gpu), + "step_across_gpu": to_py_num(step_loss_across_gpus), + "interval_per_gpu": to_py_num(interval_loss_per_gpu), + "interval_across_gpu": to_py_num(interval_loss_across_gpus), }, "bpb": { - "interval_per_gpu": interval_bpb_per_gpu, - "interval_across_gpus": interval_bpb_across_gpus, + "interval_per_gpu": to_py_num(interval_bpb_per_gpu), + "interval_across_gpus": to_py_num(interval_bpb_across_gpus), }, "n_bytes": { - "interval_per_gpu": interval_total_n_bytes_per_gpu, - "interval_across_gpus": interval_total_n_bytes_across_gpus, + "interval_per_gpu": to_py_num(interval_total_n_bytes_per_gpu), + "interval_across_gpus": to_py_num( + interval_total_n_bytes_across_gpus + ), }, } @@ -676,8 +685,8 @@ def train(args: TrainArgs): logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}" - f" loss_avg: {round(interval_loss_across_gpus, 4):>7}" + f" loss_gpu: {round(to_py_num(interval_loss_per_gpu), 4):>7}" + f" loss_avg: {round(to_py_num(interval_loss_across_gpus), 4):>7}" f" bpb_gpu: {interval_bpb_per_gpu:3f}" f" bpb_avg: {interval_bpb_across_gpus:3f}" f" grad: {grad_norm:.2e}" @@ -702,6 +711,9 @@ def train(args: TrainArgs): if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) saved = checkpoint.save( model, optimizer, @@ -743,6 +755,9 @@ def train(args: TrainArgs): if preemption_flag["flag"]: if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, @@ -754,6 +769,9 @@ def train(args: TrainArgs): sys.exit(0) if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, From c49e25171e269575ad86a570777d145efeaecc7c Mon Sep 17 00:00:00 2001 From: Srinivasan Iyer Date: Fri, 14 Feb 2025 11:16:49 -0800 Subject: [PATCH 2/4] Update README.md (#58) --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 184cf1f..c5a1fe7 
100644 --- a/README.md +++ b/README.md @@ -63,11 +63,11 @@ Now launch a debug job to check if everything works. **The provided configuratio ```bash # stool stands for SLURM tool ! -python -m bytelatent.stool script=bytelatent.train config=apps/bytelatent/configs/debug.yaml nodes=1 partition= +python -m bytelatent.stool script=bytelatent.train config=bytelatent/configs/debug.yaml nodes=1 partition= # if you want to launch locally you can use torchrun -torchrun --nproc-per-node 8 -m bytelatent.train config=apps/bytelatent/configs/debug.yaml +torchrun --nproc-per-node 8 -m bytelatent.train config=bytelatent/configs/debug.yaml # or you can also launch on 1 GPU -python -m bytelatent.train config=apps/bytelatent/configs/debug.yaml +python -m bytelatent.train config=bytelatent/configs/debug.yaml ``` When using `stool`, if a job crashes, it can be relaunched using sbatch: From f3e8125f7407581e841dd7bec9ab4c12138f8505 Mon Sep 17 00:00:00 2001 From: Srinivasan Iyer Date: Fri, 14 Feb 2025 11:22:03 -0800 Subject: [PATCH 3/4] using apex rmsnorm (#57) * using apex rmsnorm * added message for missing apex * black * missed a print --------- Co-authored-by: Srini Iyer --- bytelatent/base_transformer.py | 39 ++++++-------------------- bytelatent/model/latent_transformer.py | 11 ++++++-- bytelatent/model/local_models.py | 9 +++++- bytelatent/transformer.py | 9 +++++- 4 files changed, 33 insertions(+), 35 deletions(-) diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 217224f..7b76b9e 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -17,6 +17,14 @@ from xformers.ops import AttentionBias, fmha from bytelatent import probe from bytelatent.tokenizers.constants import EOS_ID +try: + from apex.normalization.fused_layer_norm import FusedRMSNorm + + RMSNorm = FusedRMSNorm +except (ImportError, ModuleNotFoundError): + print("Apex not found. Using nn.RMSNorm") + RMSNorm = nn.RMSNorm + if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) else: @@ -294,37 +302,6 @@ class RotaryEmbedding(torch.nn.Module): return self.freqs_cis[0:seqlen] -class RMSNorm(nn.Module): - """ - Initialize the RMSNorm normalization layer. - - Args: - dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. - - Attributes: - eps (float): A small value added to the denominator for numerical stability. - weight (nn.Parameter): Learnable scaling parameter. 
- - """ - - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x: torch.Tensor): - return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps) - - def forward(self, x: torch.Tensor): - x = probe.log_stats(x, "resid") - output = self._norm(x.float()) - return (output * self.weight.float()).type_as(x) - - def reset_parameters(self): - torch.nn.init.ones_(self.weight) # type: ignore - - def _reshape_for_attn_bias( attn_bias: AttentionBias | None, *tensors: torch.Tensor, diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py index d91f49f..95b6d8b 100644 --- a/bytelatent/model/latent_transformer.py +++ b/bytelatent/model/latent_transformer.py @@ -12,12 +12,19 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, BaseTransformerArgs, - RMSNorm, flex_attention_comp, repeat_kv, ) from bytelatent.model.utils import create_causal_mask +try: + from apex.normalization.fused_layer_norm import FusedRMSNorm + + RMSNorm = FusedRMSNorm +except (ImportError, ModuleNotFoundError): + print("Apex not found. Using nn.RMSNorm") + RMSNorm = nn.RMSNorm + logger = logging.getLogger() @@ -44,7 +51,7 @@ class CrossAttention(nn.Module): self.n_kv_heads = n_kv_heads self.heads_per_group = self.n_heads // self.n_kv_heads - self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) self.wq = nn.Linear( diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index d92a1fb..353c878 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -14,7 +14,6 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformerArgs, InitStdFactor, - RMSNorm, RotaryEmbedding, TransformerBlock, ) @@ -22,6 +21,14 @@ from bytelatent.model.latent_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID +try: + from apex.normalization.fused_layer_norm import FusedRMSNorm + + RMSNorm = FusedRMSNorm +except (ImportError, ModuleNotFoundError): + print("Apex not found. Using nn.RMSNorm") + RMSNorm = nn.RMSNorm + logger = logging.getLogger() diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index b65e502..2e45ea5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -19,11 +19,18 @@ from xformers.ops import AttentionBias, fmha from bytelatent.base_transformer import ( BaseTransformer, BaseTransformerArgs, - RMSNorm, cross_entropy, ) from bytelatent.model.utils import create_causal_mask +try: + from apex.normalization.fused_layer_norm import FusedRMSNorm + + RMSNorm = FusedRMSNorm +except (ImportError, ModuleNotFoundError): + print("Apex not found. 
Using nn.RMSNorm") + RMSNorm = nn.RMSNorm + def attention_flops_per_token(n_layers, seq_len, dim, causal): # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30 From bec016482091e1e7442907a3a39b20a56b527a3a Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 14 Feb 2025 21:03:56 +0000 Subject: [PATCH 4/4] Make it possible to specify multiple config files Summary: Test Plan: Test that this iterpolates in the right order, config -> configs -> cli args ``` # All three sources python -m bytelatent.print_config config=bytelatent/configs/debug.yaml configs=[internal/configs/s3_debug.yaml] eval=null # What worked before python -m bytelatent.print_config config=internal/configs/s3_debug.yaml eval=null ``` --- bytelatent/args.py | 14 -- bytelatent/config_parser.py | 70 ++++++++++ bytelatent/configs/debug.yaml | 4 +- bytelatent/configs/entropy_model.yaml | 4 +- bytelatent/eval.py | 8 +- bytelatent/print_config.py | 11 ++ bytelatent/test_config_parser.py | 180 ++++++++++++++++++++++++++ bytelatent/train.py | 5 +- fixtures/test-cfgs/list.yaml | 1 + fixtures/test-cfgs/middle.yaml | 3 + fixtures/test-cfgs/override.yaml | 1 + fixtures/test-cfgs/root.yaml | 6 + fixtures/test-cfgs/top.yaml | 3 + 13 files changed, 283 insertions(+), 27 deletions(-) create mode 100644 bytelatent/config_parser.py create mode 100644 bytelatent/print_config.py create mode 100644 bytelatent/test_config_parser.py create mode 100644 fixtures/test-cfgs/list.yaml create mode 100644 fixtures/test-cfgs/middle.yaml create mode 100644 fixtures/test-cfgs/override.yaml create mode 100644 fixtures/test-cfgs/root.yaml create mode 100644 fixtures/test-cfgs/top.yaml diff --git a/bytelatent/args.py b/bytelatent/args.py index 47bd0f9..dd1fef5 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,7 +5,6 @@ from typing import Any import numpy as np import yaml -from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state -def parse_args(args_cls): - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(args_cls().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - pydantic_args = args_cls.model_validate(cfg) - return pydantic_args - - TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" diff --git a/bytelatent/config_parser.py b/bytelatent/config_parser.py new file mode 100644 index 0000000..6e60972 --- /dev/null +++ b/bytelatent/config_parser.py @@ -0,0 +1,70 @@ +import copy +from typing import Type + +import omegaconf +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel + + +def parse_file_config(path: str) -> DictConfig: + file_cfg = OmegaConf.load(path) + if not isinstance(file_cfg, DictConfig): + raise ValueError( + f"File paths must parse to DictConfig, but it was: {type(file_cfg)}" + ) + return file_cfg + + +def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]: + if "config" not in cfg: + return [cfg] + + ordered_cfgs = [] + cfg = copy.deepcopy(cfg) + config_arg = cfg["config"] + del cfg["config"] + ordered_cfgs.append(cfg) + + if 
isinstance(config_arg, str): + file_cfg = parse_file_config(config_arg) + sub_configs = recursively_parse_config(file_cfg) + ordered_cfgs = sub_configs + ordered_cfgs + elif isinstance(config_arg, omegaconf.listconfig.ListConfig): + sub_configs = [] + for c in config_arg: + if not isinstance(c, str): + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths. It was config={config_arg}' + ) + config_to_parse = parse_file_config(c) + sub_configs.extend(recursively_parse_config(config_to_parse)) + ordered_cfgs = sub_configs + ordered_cfgs + else: + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}' + ) + return ordered_cfgs + + +def parse_args_with_default( + *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None +): + if cli_args is None: + cli_args = OmegaConf.from_cli() + assert isinstance( + cli_args, DictConfig + ), f"CLI Args must be a DictConfig, not {type(cli_args)}" + ordered_cfgs = recursively_parse_config(cli_args) + if default_cfg is not None: + ordered_cfgs.insert(0, default_cfg) + cfg = OmegaConf.merge(*ordered_cfgs) + return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + + +def parse_args_to_pydantic_model( + args_cls: Type[BaseModel], cli_args: DictConfig | None = None +): + default_cfg = OmegaConf.create(args_cls().model_dump()) + parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args) + pydantic_args = args_cls.model_validate(parsed_cfg) + return pydantic_args diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 07d489f..1369364 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -56,13 +56,11 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - patch_only_encoder: false - patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" - attn_bias_type: "block_causal" attn_impl: "xformers" + attn_bias_type: "block_causal" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml index d7c27b7..79cc85b 100644 --- a/bytelatent/configs/entropy_model.yaml +++ b/bytelatent/configs/entropy_model.yaml @@ -2,9 +2,10 @@ # Evals can be activated by uncommenting its config # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest -dump_dir: /tmp/ +dump_dir: /tmp/blt-entropy name: "debug" steps: 100_000 +max_steps: null probe_freq: null seed: 777 optim: @@ -35,7 +36,6 @@ entropy_model: attn_impl: "xformers" data: - s3_profile: blt root_dir: ??? 
sources: dclm_baseline_1.0: 1.0 diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..50e17cd 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -5,18 +5,15 @@ import logging import os from collections import defaultdict from datetime import datetime -from pathlib import Path -from typing import Any import torch from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM -from omegaconf import OmegaConf -from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import EvalArgs, ValidationArgs from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, @@ -29,7 +26,6 @@ from bytelatent.generate import ( PackedCausalTransformerGenerator, load_consolidated_model_and_tokenizer, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" diff --git a/bytelatent/print_config.py b/bytelatent/print_config.py new file mode 100644 index 0000000..0bc99e7 --- /dev/null +++ b/bytelatent/print_config.py @@ -0,0 +1,11 @@ +from bytelatent.args import TrainArgs +from bytelatent.config_parser import parse_args_to_pydantic_model + + +def main(): + train_args = parse_args_to_pydantic_model(TrainArgs) + print(train_args.model_dump_json(indent=4)) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/test_config_parser.py b/bytelatent/test_config_parser.py new file mode 100644 index 0000000..c1ec99b --- /dev/null +++ b/bytelatent/test_config_parser.py @@ -0,0 +1,180 @@ +import os + +import pytest +from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf +from pydantic import BaseModel, ConfigDict + +from bytelatent.config_parser import ( + parse_args_to_pydantic_model, + parse_file_config, + recursively_parse_config, +) + +FIXTURE_DIR = "fixtures/test-cfgs" + + +def test_parse_file_config(): + with pytest.raises(ValueError): + cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml")) + assert isinstance(cfg, DictConfig) + + +def test_nop(): + cfg = OmegaConf.create({"a": 1}) + parsed_cfgs = recursively_parse_config(cfg) + assert len(parsed_cfgs) == 1 + assert parsed_cfgs[0] == cfg + + +def test_root(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 2 + assert len(parsed_cfgs[1]) == 0 + assert parsed_cfgs[0]["seed"] == -1 + with pytest.raises(MissingMandatoryValue): + assert parsed_cfgs[0]["b"]["y"] is not None + + # Test basic cli override + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert parsed_cfgs[1]["seed"] == 42 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["seed"] == 42 + + +def test_one_level_include(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert len(parsed_cfgs[2]) == 0 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 10 + + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}} + ) + parsed_cfgs = 
recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["b"]["y"] == 100 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 100 + + +def test_two_level_include(): + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 4 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["p"] == 500 + assert parsed_cfgs[3]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +def test_multiple_includes(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 100 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + "a": 1000, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1000 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +class SubConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + x: int = -100 + y: int = -100 + z: int = -5 + + +class SampleConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + a: int = -100 + seed: int = -100 + b: SubConfig = SubConfig() + hello: str = "" + p: int = -100 + + +def test_pydantic_parse(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "a": 1000, + } + ) + cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg) + assert isinstance(cfg, SampleConfig) + assert cfg.a == 1000 + assert cfg.p == 500 + assert cfg.seed == -1 + assert cfg.b.x == 0 + assert cfg.b.y == 10 + assert cfg.b.z == -5 + assert cfg.hello == "world" diff --git a/bytelatent/train.py b/bytelatent/train.py index 3669167..ad74b44 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -23,8 +23,9 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs, parse_args +from bytelatent.args import TrainArgs from bytelatent.checkpoint import 
CheckpointManager, load_from_checkpoint +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( @@ -824,7 +825,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - train_args = parse_args(TrainArgs) + train_args = parse_args_to_pydantic_model(TrainArgs) if train_args.debug_dynamo: import torch._dynamo diff --git a/fixtures/test-cfgs/list.yaml b/fixtures/test-cfgs/list.yaml new file mode 100644 index 0000000..b5d8bb5 --- /dev/null +++ b/fixtures/test-cfgs/list.yaml @@ -0,0 +1 @@ +[1, 2, 3] diff --git a/fixtures/test-cfgs/middle.yaml b/fixtures/test-cfgs/middle.yaml new file mode 100644 index 0000000..a476d8d --- /dev/null +++ b/fixtures/test-cfgs/middle.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/root.yaml +b: + y: 10 diff --git a/fixtures/test-cfgs/override.yaml b/fixtures/test-cfgs/override.yaml new file mode 100644 index 0000000..456df7b --- /dev/null +++ b/fixtures/test-cfgs/override.yaml @@ -0,0 +1 @@ +a: 100 diff --git a/fixtures/test-cfgs/root.yaml b/fixtures/test-cfgs/root.yaml new file mode 100644 index 0000000..dc4d285 --- /dev/null +++ b/fixtures/test-cfgs/root.yaml @@ -0,0 +1,6 @@ +seed: -1 +a: 1 +b: + x: 0 + y: ??? + z: ??? diff --git a/fixtures/test-cfgs/top.yaml b/fixtures/test-cfgs/top.yaml new file mode 100644 index 0000000..632866c --- /dev/null +++ b/fixtures/test-cfgs/top.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/middle.yaml + +hello: world
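[Note] The config parser introduced in PATCH 4/4 resolves `config=` entries recursively and merges sources from lowest to highest precedence: pydantic defaults, then each included file (depth-first, in listed order), then the file that included them, then CLI overrides. The snippet below is a minimal standalone sketch of that merge order using only `omegaconf`; the in-memory dicts stand in for the fixture files under `fixtures/test-cfgs/`, and the assertions mirror what `test_pydantic_parse` checks.

```python
# Standalone sketch of the merge order parse_args_to_pydantic_model relies on.
from omegaconf import OmegaConf

# Lowest precedence: defaults derived from the pydantic model.
default_cfg = OmegaConf.create({"a": -100, "seed": -100, "b": {"x": -100, "y": -100, "z": -5}})

# Stand-ins for the fixture files (values copied from the patch above).
root_cfg = OmegaConf.create({"seed": -1, "a": 1, "b": {"x": 0}})   # root.yaml
middle_cfg = OmegaConf.create({"b": {"y": 10}})                    # middle.yaml (includes root.yaml)
top_cfg = OmegaConf.create({"hello": "world"})                     # top.yaml (includes middle.yaml)
override_cfg = OmegaConf.create({"a": 100})                        # override.yaml, listed after top.yaml
cli_cfg = OmegaConf.create({"p": 500, "a": 1000})                  # CLI overrides win last

merged = OmegaConf.merge(default_cfg, root_cfg, middle_cfg, top_cfg, override_cfg, cli_cfg)
assert merged.a == 1000 and merged.seed == -1 and merged.b.y == 10 and merged.hello == "world"
print(OmegaConf.to_yaml(merged))
```

The real `recursively_parse_config` produces exactly this ordered list of configs (included files first, CLI args last) before a single `OmegaConf.merge`, which is why `override.yaml` beats `top.yaml` and CLI values beat both.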