From 7622d28b749b632498f5c9ecddd46e0e099ab1df Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 27 Jan 2025 09:46:44 -0800 Subject: [PATCH 1/2] Initial codes and scripts for training entropy model (#34) Summary: Test Plan: --- .gitignore | 1 + bytelatent/args.py | 13 ++- bytelatent/configs/debug.yaml | 3 +- bytelatent/configs/entropy_model.yaml | 82 +++++++++++++++++++ bytelatent/data/data_types.py | 2 +- bytelatent/data/iterators/packing_iterator.py | 42 ++++++++++ .../data/iterators/sequence_iterator.py | 30 +++++-- bytelatent/data/patcher.py | 10 ++- bytelatent/model/blt.py | 5 +- bytelatent/test_blt.py | 3 +- bytelatent/train.py | 52 +++++++++--- 11 files changed, 209 insertions(+), 34 deletions(-) create mode 100644 bytelatent/configs/entropy_model.yaml diff --git a/.gitignore b/.gitignore index 6c664b8..d1d7c2a 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ figures/ .vscode/ .DS_Store internal/ +jobs_parallel-copy/ diff --git a/bytelatent/args.py b/bytelatent/args.py index a332c89..56de22d 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel): max_encoder_seq_length: int = 12288 enable_byte_ngrams: bool = False + add_patches: bool = True + tokenizer_args: TokenizerArgs = TokenizerArgs() patcher_args: PatcherArgs = PatcherArgs() @@ -120,6 +122,7 @@ class DataloaderArgs(BaseModel): looping_iterator, patcher_args=self.patcher_args, tokenizer_args=self.tokenizer_args, + add_patches=self.add_patches, ) sequence_iterator = SequenceIterator( preprocess_iterator, @@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel): source_to_iterator=source_to_sequence_iterators, ) tokenizer = self.tokenizer_args.build() + if self.tokenizer_args.name == "bytes": + # TODO: Check this with Artidoro + pad_id = 0 + else: + pad_id = tokenizer.boe_id packing_args = PackingArgs( batch_size=self.batch_size, seq_len=self.seq_len, - pad_id=tokenizer.boe_id, + pad_id=pad_id, max_length=self.max_encoder_seq_length, pad_to_max_length=self.pad_to_max_length, enable_byte_ngrams=self.enable_byte_ngrams, + tokenizer_name=self.tokenizer_args.name, ) packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) if self.load_async: @@ -180,7 +189,7 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() - model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs() # This is only needed for training the entropy model entropy_model: LMTransformerArgs | None = None # Instead of training main model, train entropy model diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 4ae4459..1098ff5 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -26,10 +26,9 @@ model: vocab_size: 260 dim_token: 256 patch_size: 6 - tokenization_mode: "bytes" patching_mode: "space" tie_local_encoder_decoder_logits: false - data_loader_patching: true + patch_in_forward: false max_encoder_seq_length: 12288 pad_to_max_length: true patching_threshold: 3.1439168453216553 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml new file mode 100644 index 0000000..51b65d4 --- /dev/null +++ b/bytelatent/configs/entropy_model.yaml @@ -0,0 +1,82 @@ +# Template config, need to change dump_dir, data.root_dir and tokenizer.path +# Evals can be activated by uncommenting its config +# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 
account=fair_amaia_cw_codegen qos=lowest + +dump_dir: /tmp/ +name: "debug" +steps: 100_000 +probe_freq: null +seed: 777 +optim: + lr: 4e-04 + warmup: 500 + lr_min_ratio: 0.1 + clip: 10.0 + +distributed: + fsdp_type: full_shard + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +train_entropy_model: true +model: null +entropy_model: + dim: 768 + n_layers: 14 + n_heads: 12 + max_seqlen: 8192 + # vocab_size: -1 + vocab_size: 260 + ffn_dim_multiplier: 1.0 + sliding_window: 512 + attn_bias_type: "local_block_causal" + attn_impl: "xformers" + +data: + s3_profile: blt + root_dir: ??? + sources: + dclm_baseline_1.0: 1.0 + batch_size: 2 + prefetch_size: 64 + # seqlen is in terms of patches and + # max_encoder_seq_length is in terms of bytes. + # For entropy model, these are the same since 1 patch=1 byte + seq_len: 8192 + max_encoder_seq_length: 8192 + load_async: true + preprocess_dir: ??? + # We don't need patches for this model + add_patches: false + patcher_args: + # This doesn't matter since byte entropy model doesn't use patching, + # so pick the most efficient, so static + patching_mode: byte + tokenizer_args: + name: bytes + +profiling: + run: false + +checkpoint: + dump: + every: 500 + keep: 3 + eval: + every: 1000 + keep: -1 + +logging: + freq: 10 + +eval_on_gpus: 8 +eval: + dataset_dir: ??? + tasks: ??? + generator: + max_tokens: 65536 + dtype: bf16 + + mp_size: 1 diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index 7e142e4..aa2daa9 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]] class BltSequence(BaseModel): tokens: list[int] mask: list[bool] - patch_lengths: list[int] + patch_lengths: list[int] | None @dataclass diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py index 361fc03..fa29149 100644 --- a/bytelatent/data/iterators/packing_iterator.py +++ b/bytelatent/data/iterators/packing_iterator.py @@ -17,6 +17,7 @@ class PackingArgs(BaseModel): max_length: int | None pad_to_max_length: bool enable_byte_ngrams: bool + tokenizer_name: str class PackingIteratorState(BaseModel, IteratorState): @@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): ) def create_iter(self): + if self.packing_args.tokenizer_name == "bytes": + return self._create_iter_from_bytes() + else: + return self._create_iter_from_patch_lengths() + + def _create_iter_from_bytes(self): + sequence_iter = self.sequence_iterator.create_iter() + batch_size = self.packing_args.batch_size + pad_id = self.packing_args.pad_id + seq_len = self.packing_args.seq_len + while True: + tokens: list[list[int]] = [] + masks: list[list[bool]] = [] + + for _ in range(self.packing_args.batch_size): + sequence = next(sequence_iter) + _tokens = sequence.tokens + _mask = sequence.mask + assert ( + sequence.patch_lengths is None + ), "patch_lengths should not be used in byte packing" + tokens.append(_tokens) + masks.append(_mask) + + x = np.full((batch_size, seq_len), fill_value=pad_id) + y = np.full((batch_size, seq_len), fill_value=pad_id) + + for i, tok_seq in enumerate(tokens): + x[i, : len(tok_seq)] = tok_seq + y[i, : len(tok_seq) - 1] = tok_seq[1:] + batch = Batch(x=x, y=y) + assert ( + batch.mask is None or np.sum(x != pad_id) == batch.mask.sum() + ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}" + yield batch + + def _create_iter_from_patch_lengths(self): 
sequence_iter = self.sequence_iterator.create_iter() batch_size = self.packing_args.batch_size pad_id = self.packing_args.pad_id @@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): _tokens = sequence.tokens _mask = sequence.mask _patch_lengths = sequence.patch_lengths + assert ( + _patch_lengths is not None + ), "patch lengths are required for packing based on patches." + # Reminder: seq_len is in terms of patches assert len(sequence.patch_lengths) == self.packing_args.seq_len last_patch_length = 0 if _patch_lengths[0] > 1: diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py index 14e3747..d90ea31 100644 --- a/bytelatent/data/iterators/sequence_iterator.py +++ b/bytelatent/data/iterators/sequence_iterator.py @@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator): for example in example_iter: assert example.tokens is not None assert example.mask is not None - assert example.patch_lengths is not None + if self.preprocess_iterator.add_patches: + assert example.patch_lengths is not None + assert len(example.tokens) == sum(example.patch_lengths) + else: + assert example.patch_lengths is None assert len(example.tokens) != 0 assert len(example.mask) != 0 assert len(example.tokens) == len(example.mask) - assert len(example.tokens) == sum(example.patch_lengths) tokens.extend(example.tokens) mask.extend(example.mask) - patch_lengths.extend(example.patch_lengths) + if self.preprocess_iterator.add_patches: + patch_lengths.extend(example.patch_lengths) + else: + # This lets the rest of the code work as expected and just yield byte seqs + patch_lengths.extend([1] * len(example.tokens)) while len(patch_lengths) >= n_buffer_patches: if first: @@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator): == len(seq_mask[idx]) ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}" assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}" - yield BltSequence( - tokens=seq_tokens[idx], - mask=seq_mask[idx], - patch_lengths=seq_patch_lengths[idx], - ) + if self.preprocess_iterator.add_patches: + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=seq_patch_lengths[idx], + ) + else: + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=None, + ) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index afcfa2e..44ff5e9 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum): bpe = "bpe" bpe_patcher = "bpe_patcher" space = "space" + static = "static" + byte = "byte" class PatcherArgs(BaseModel): @@ -34,7 +36,6 @@ class PatcherArgs(BaseModel): max_patch_length: int | None = None patch_size: float = 4.5 patching_batch_size: int = 1 - data_loader_patching: bool = False device: str = "cuda" monotonicity: bool = False log_time: bool = False @@ -486,7 +487,6 @@ class Patcher: self.max_patch_length = patcher_args.max_patch_length self.patch_size = patcher_args.patch_size self.patching_batch_size = patcher_args.patching_batch_size - self.data_loader_patching = patcher_args.data_loader_patching self.device = patcher_args.device self.monotonicity = patcher_args.monotonicity self.log_time = patcher_args.log_time @@ -528,7 +528,7 @@ class Patcher: seq_len_next_tok = seq_len + 1 if include_next_token else seq_len scores = None # STATIC - if self.patching_mode is None: + if self.patching_mode == PatchingModeEnum.static: patch_lengths = 
torch.zeros( (bs, math.ceil(seq_len_next_tok / self.patch_size)), dtype=tokens.dtype, @@ -536,6 +536,10 @@ class Patcher: ).fill_(self.patch_size) if seq_len_next_tok % self.patch_size != 0: patch_lengths[:, -1] = seq_len_next_tok % self.patch_size + elif self.patching_mode == PatchingModeEnum.byte: + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device + ) # ENTROPY elif self.patching_mode == PatchingModeEnum.entropy: if self.log_time: diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 843ad34..a62be23 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? weight_tying: bool = False + patch_in_forward: bool = False # Architecture and dimensions dim_token: int = 256 @@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_layers_local_encoder: int = 8 # Tokenization and patching - tokenization_mode: str = "bpe" patch_size: float | None = None patching_mode: str | None = None patching_threshold: float | None = None @@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): monotonicity: bool = False patching_batch_size: int = 1 patching_device: str = "cuda" - data_loader_patching: bool = False max_patch_length: int | None = None # Encoder/Decoder configuration @@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module): self.output.weight = self.tok_embeddings.weight # Patcher module - if not args.data_loader_patching: + if args.patch_in_forward: self.patcher = Patcher( PatcherArgs( patch_size=args.patch_size, diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 36a9882..eb94df3 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -68,10 +68,9 @@ def create_args(cross_attention=False): # Additional args from command line dim_token=256, patch_size=6, - tokenization_mode="bytes", patching_mode="space", tie_local_encoder_decoder_logits=False, - data_loader_patching=True, + patch_in_forward=False, max_encoder_seq_length=12288, pad_to_max_length=True, encoder_lm_loss=False, diff --git a/bytelatent/train.py b/bytelatent/train.py index 80bd393..1d0fa40 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler from bytelatent.stool import StoolArgs, launch_job from bytelatent.transformer import ( + LMTransformer, build_fsdp_grouping_plan, get_no_recompute_ops, get_num_flop_per_token, @@ -103,10 +104,15 @@ class TrainState(Stateful): def validate_train_args(args: TrainArgs, output_size: int): - if args.model.vocab_size < 0: + assert args.model is not None or args.entropy_model is not None + if args.model is not None: logger.info(f"Setting model output size to {args.model.vocab_size}") args.model.vocab_size = output_size + if args.entropy_model is not None: + logger.info(f"Setting model output size to {args.entropy_model.vocab_size}") + args.entropy_model.vocab_size = output_size + assert args.dump_dir, "Dump dir not set" if args.checkpoint.path is None: @@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int): and args.distributed.dp_replicate == get_world_size() ) - args.model.max_seqlen = args.data.seq_len + if args.model is not None: + args.model.max_seqlen = args.data.seq_len + if args.entropy_model is not None: + args.entropy_model.max_seqlen = args.data.seq_len if args.distributed.tp_size == 1: 
logger.warning( @@ -237,7 +246,14 @@ def train(args: TrainArgs): # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory with torch.device("meta"): - model = ByteLatentTransformer(args.model) + if args.train_entropy_model: + assert args.entropy_model is not None + model = LMTransformer(args.entropy_model) + model_args = args.entropy_model + else: + assert args.model is not None + model = ByteLatentTransformer(args.model) + model_args = args.model logger.info("Model is built !") model_param_count = get_num_params(model) @@ -247,7 +263,7 @@ def train(args: TrainArgs): world_mesh, args.model, args.distributed, - fsdp_grouping_plan=build_fsdp_grouping_plan(args.model), + fsdp_grouping_plan=build_fsdp_grouping_plan(model_args), tp_parallelize=tp_parallelize, no_recompute_ops=get_no_recompute_ops(), ) @@ -267,7 +283,7 @@ def train(args: TrainArgs): model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: with torch.random.fork_rng(devices=[torch.cuda.current_device()]): - torch.manual_seed(args.model.seed) + torch.manual_seed(model_args.seed) model.init_weights() check_model_value_range(model, range=10.0, std=1.0) @@ -342,10 +358,17 @@ def train(args: TrainArgs): batch.x, ).cuda() batch_y = torch.from_numpy(batch.y).cuda() - batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() + if batch.patch_lengths is None: + batch_patch_lengths = None + else: + batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() - if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None: + if ( + not args.train_entropy_model + and args.model.encoder_enable_byte_ngrams + and batch.ngram_ids is None + ): raise ValueError( "Cannot enable byte ngrams and have batch.ngram_ids be None" ) @@ -408,9 +431,12 @@ def train(args: TrainArgs): next(probe_mod.parameters()).grad is None ), "Probe model shouldn't have grads at this point" - pred = model( - batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids - ) + if args.train_entropy_model: + pred = model(batch_x) + else: + pred = model( + batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids + ) loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) @@ -474,9 +500,9 @@ def train(args: TrainArgs): # Use xformer's analyze profile trace to get actual measurement FLOPS = ( get_num_flop_per_token( - model_param_count - args.model.vocab_size * args.model.dim, - args.model.n_layers, - args.model.dim, + model_param_count - model_args.vocab_size * model_args.dim, + model_args.n_layers, + model_args.dim, args.data.seq_len, ) * wps From e02ba763b00aa7075752a3615f59243c210ed188 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 28 Jan 2025 00:38:16 +0000 Subject: [PATCH 2/2] This includes fixes that make checkpointing and reloading work correctly. 
It also batches in a first set of changes for fixing eval code Summary: Test Plan: --- apps/main/lingua_train.py | 2 +- bytelatent/args.py | 87 ++++++++- bytelatent/checkpoint.py | 9 +- bytelatent/data/data_types.py | 10 - .../data/iterators/multiprocess_iterator.py | 15 ++ {apps/main => bytelatent}/eval.py | 179 ++++-------------- {apps/main => bytelatent}/generate.py | 57 +++--- bytelatent/train.py | 81 ++++---- 8 files changed, 219 insertions(+), 221 deletions(-) rename {apps/main => bytelatent}/eval.py (56%) rename {apps/main => bytelatent}/generate.py (91%) diff --git a/apps/main/lingua_train.py b/apps/main/lingua_train.py index bdb47da..7925ec6 100644 --- a/apps/main/lingua_train.py +++ b/apps/main/lingua_train.py @@ -544,7 +544,7 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval + from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval eval_args = dataclass_from_dict(EvalArgs, args.eval) diff --git a/bytelatent/args.py b/bytelatent/args.py index 56de22d..d1bac46 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,6 +5,7 @@ from typing import Any import numpy as np import yaml +from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state +def parse_args(args_cls): + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.create(args_cls().model_dump()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + pydantic_args = args_cls.model_validate(cfg) + return pydantic_args + + def distribute_data_to_rank( *, dataset_path: str, @@ -71,6 +85,22 @@ def distribute_data_to_rank( return rank_to_arrow_iterator_params[rank] +class PackedCausalTransformerGeneratorArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + temperature: float = 0.0 + top_p: float | None = None + top_k: float | None = None + max_gen_len: int = 512 # Maximum number of tokens to generate + max_tokens: int = 1024 # Maximum number of tokens that can go through the model + max_prompt_len: int | None = None + until: list[str] = [] + compile_prefilling: bool = False + reduce_generation_overhead: bool = False + show_progress: bool = False + dtype: str | None = "bf16" + device: str | None = "cuda" + + class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") s3_profile: str | None = None @@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel): return packing_iterator +class LMHarnessArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + tasks: list[Any] | None = None + num_fewshot: int | None = None + device: str | None = None + use_cache: str | None = None + cache_requests: bool = False + rewrite_requests_cache: bool = False + delete_requests_cache: bool = False + limit: int | float | None = None + bootstrap_iters: int = 100000 + check_integrity: bool = False + write_out: bool = False + log_samples: bool = True + system_instruction: str | None = None + apply_chat_template: bool | str = False + fewshot_as_multiturn: bool = False + gen_kwargs: str | None = None + 
verbosity: str = "INFO" + predict_only: bool = False + random_seed: int = 0 + numpy_random_seed: int = 1234 + torch_random_seed: int = 1234 + fewshot_random_seed: int = 1234 + + +class ValidationArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + max_steps: int | None = ( + None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) + ) + use_val_from_train_src: bool = True # Use the validation set from training sources + root_dir: str = "" + sources: list[str] = [] # Other sources to eval on + + +class EvalArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dump_dir: str + ckpt_dir: str + metric_log_dir: str | None = None + generator: PackedCausalTransformerGeneratorArgs = ( + PackedCausalTransformerGeneratorArgs() + ) + + harness: LMHarnessArgs | None = LMHarnessArgs() + validation: ValidationArgs | None = ValidationArgs() + + global_step: int | None = None # for in-training evaluation + s3_profile: str | None = None + + class TrainArgs(BaseModel): model_config = ConfigDict(extra="forbid") name: str = "lingua" @@ -186,6 +268,9 @@ class TrainArgs(BaseModel): # Nb optimizer steps to take steps: int = 1000 + # If not None, halt training after this many steps, + # useful for debugging + max_steps: int | None = None data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() @@ -203,7 +288,7 @@ class TrainArgs(BaseModel): # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus async_eval_gpus: int | None = None - eval: Any | None = None + eval: EvalArgs | None = None eval_on_gpus: int | None = None def dump_to_yaml_file( diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index bcf591e..f213c84 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -7,6 +7,7 @@ import re from pathlib import Path from typing import List, Optional, Tuple +import fsspec import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp @@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import ( set_state_dict, ) +from bytelatent.data.file_util import get_fs from bytelatent.distributed import get_is_master logger = logging.getLogger("CHECKPOINT") @@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel): path: str | None = None init_ckpt_path: str | None = None continue_training_from_init: bool = False + s3_profile: str | None = None def _get_key_step(name: str): return int(re.findall(RE_DIGITS, name)[-1]) -def consolidate_checkpoints(ckpt_dir: str): +def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): """ Consolidates all FSDP checkpoints in a directory to a single file Consolidate checkpoint is saved in a subdirectory of ckpt_dir @@ -102,15 +105,17 @@ def load_from_checkpoint( dcp.load(state_dict, checkpoint_id=ckpt_dir) +# TODO: Rewrite the file operations here to use fsspec to enable s3 writing. 
class CheckpointManager: def __init__(self, args: CheckpointArgs): self.path = args.path + self.fs = get_fs(self.path, s3_profile=args.s3_profile) self.dump_every = args.dump self.eval_every = args.eval self.init_ckpt_path = args.init_ckpt_path self.continue_training_from_init = args.continue_training_from_init - assert os.path.exists( + assert self.fs.exists( self.path ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index aa2daa9..f4bbc07 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel): n_views: int = 2 -class DataLoaderState(BaseModel): - model_config = ConfigDict(extra="forbid") - multi_choice_state: MultiChoiceState - pack_tokens_state: BltPackTokensState - prefetch_state: PrefetchState - - -BltIterator = Iterator[tuple[BltExample, DataLoaderState]] - - class BltSequence(BaseModel): tokens: list[int] mask: list[bool] diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index f17ca6e..49d99ac 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator): self.producer = None self.stop_iterating_event = None self.state_dumped_event = None + self.force_shutdown = False + + def shutdown(self): + if self.producer is not None: + # This properly shuts things down + self.producer.kill() + self.force_shutdown = True def get_state(self) -> MultiprocessIteratorState: """ @@ -135,6 +142,10 @@ class MultiprocessIterator(StatefulIterator): to halt the background process and allow it to write the state to the main loop in order to not lose data """ + if self.force_shutdown: + raise ValueError( + "State will be invalid if shutdown was forced before state persisted." + ) if self.producer is None: serialized_prefetch_buffer = json.dumps( [b.to_python_dict() for b in self.prefetch_buffer] @@ -187,6 +198,10 @@ class MultiprocessIterator(StatefulIterator): ) def create_iter(self): + if self.force_shutdown: + raise ValueError( + "Iterator may be invalid if shutdown was forced before state persisted." + ) logging.info("Main thread: Creating MP iterator") # First yield from the stored prefetch buffer. 
if self.prefetch_buffer is not None: diff --git a/apps/main/eval.py b/bytelatent/eval.py similarity index 56% rename from apps/main/eval.py rename to bytelatent/eval.py index ed20f49..ae73066 100644 --- a/apps/main/eval.py +++ b/bytelatent/eval.py @@ -4,20 +4,20 @@ import json import logging import os from collections import defaultdict -from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any import torch -from lingua.args import dump_config -from lingua.data import init_choice_state, setup_sources from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM from omegaconf import OmegaConf +from pydantic import BaseModel, ConfigDict +from bytelatent.args import EvalArgs, ValidationArgs, parse_args from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -25,72 +25,17 @@ from bytelatent.distributed import ( get_world_size, setup_torch_distributed, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs - -from apps.main.generate import ( +from bytelatent.generate import ( PackedCausalTransformerGenerator, - PackedCausalTransformerGeneratorArgs, load_consolidated_model_and_tokenizer, ) +from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" logger = logging.getLogger() -@dataclass -class LMHarnessArgs: - tasks: Optional[List[Any]] = None - num_fewshot: Optional[int] = None - device: Optional[str] = None - use_cache: Optional[str] = None - cache_requests: bool = False - rewrite_requests_cache: bool = False - delete_requests_cache: bool = False - limit: Optional[Union[int, float]] = None - bootstrap_iters: int = 100000 - check_integrity: bool = False - write_out: bool = False - log_samples: bool = True - system_instruction: Optional[str] = None - apply_chat_template: Union[bool, str] = False - fewshot_as_multiturn: bool = False - gen_kwargs: Optional[str] = None - verbosity: str = "INFO" - predict_only: bool = False - random_seed: int = 0 - numpy_random_seed: int = 1234 - torch_random_seed: int = 1234 - fewshot_random_seed: int = 1234 - - -@dataclass -class ValidationArgs: - max_steps: Optional[int] = ( - None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) - ) - use_val_from_train_src: bool = True # Use the validation set from training sources - root_dir: str = "" - sources: List[str] = field(default_factory=list) # Other sources to eval on - - -@dataclass -class EvalArgs: - name: str = "evals" - dump_dir: Optional[str] = None - metric_log_dir: Optional[str] = None - ckpt_dir: str = "" - generator: PackedCausalTransformerGeneratorArgs = field( - default_factory=PackedCausalTransformerGeneratorArgs - ) - harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs) - validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs) - - wandb: Optional[Any] = None - - global_step: Optional[int] = None # for in-training evaluation - - def all_dicts_same(dict_list): if not dict_list: # Check if the list is empty return True @@ -120,7 +65,7 @@ class EvalHarnessLM(LM): self._world_size = get_world_size() self.device = generator.device - def generate_until(self, requests: List[Instance]) -> List[str]: + def 
generate_until(self, requests: list[Instance]) -> list[str]: prompts, gen_args = zip(*[req.args for req in requests]) assert all_dicts_same(gen_args), "Doesn't support different gen args for now" gen_args = gen_args[0] @@ -141,7 +86,7 @@ class EvalHarnessLM(LM): filtered_gen.append(g) return filtered_gen - def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: prompts, continuations = zip(*[req.args for req in requests]) inputs = [req.args[0] + req.args[1] for req in requests] max_gen_len = self.generator.max_gen_len @@ -158,7 +103,7 @@ class EvalHarnessLM(LM): self.generator.max_gen_len = max_gen_len return results - def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: prompts = [req.args[0] for req in requests] max_gen_len = self.generator.max_gen_len # We temporarily lower max gen len @@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): return all_val_metrics -def launch_eval(cfg: EvalArgs): +def launch_eval(eval_args: EvalArgs): if not torch.distributed.is_initialized(): setup_torch_distributed(DistributedArgs()) + + fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile) if ( - Path(cfg.ckpt_dir).exists() - and (Path(cfg.ckpt_dir) / "params.json").exists() - and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None + fs.exists(eval_args.ckpt_dir) + and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json")) + and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0 ): - consolidate_path = Path(cfg.ckpt_dir) + consolidate_path = eval_args.ckpt_dir else: - consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER - if not consolidate_path.exists() and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(cfg.ckpt_dir) + consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) + if not fs.exists(consolidate_path) and get_global_rank() == 0: + consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) - Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True) - dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False) + fs.mkdirs(eval_args.dump_dir, exist_ok=True) + with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: + f.write(eval_args.model_dump_json()) - consolidate_path = str(consolidate_path) torch.distributed.barrier() logger.info("Loading model") + # TODO: Make this general so that it works with either + # LMTransformer or Blt, similar with args model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( consolidate_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ) logger.info("Model loaded") model.eval() - generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer) + generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer) wrap = EvalHarnessLM(generator) - results = simple_evaluate(wrap, **asdict(cfg.harness)) + # Redo + results = simple_evaluate(wrap, eval_args.harness.model_dump()) val_results = None - if cfg.validation: - val_results = eval_on_val(generator, cfg.validation, train_cfg) + if eval_args.validation: + val_results = eval_on_val(generator, eval_args.validation, train_cfg) if get_global_rank() == 0: - with open(Path(cfg.dump_dir) / "results.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) 
logger.info(f"All evaluation results: {results['results']}") if val_results is not None: - with open(Path(cfg.dump_dir) / "validation.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") - if cfg.metric_log_dir and get_global_rank() == 0: - metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl" + if eval_args.metric_log_dir and get_global_rank() == 0: + metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") logger.info(f"Writing metric logs to {metric_log_path}") timestamp = { "created_at": datetime.utcnow().isoformat(), } - if cfg.global_step is not None: - timestamp["global_step"] = cfg.global_step + if eval_args.global_step is not None: + timestamp["global_step"] = eval_args.global_step print( json.dumps(timestamp | results["results"]), - file=open(metric_log_path, mode="a"), + file=fs.open(metric_log_path, mode="a"), flush=True, ) - val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl" + val_log_path = os.path.join( + eval_args.metric_log_dir, "metrics.validation.jsonl" + ) if val_results is not None: print( json.dumps(timestamp | val_results), - file=open(val_log_path, mode="a"), + file=fs.open(val_log_path, mode="a"), flush=True, ) @@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs): def main(): - """ - The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments - This accepts arguments as a dot list - So if the dataclass looks like - - @dataclass - class DummyArgs: - name: str - model: LMTransformerArgsgs - - @dataclass - class LMTransformerArgsgs: - dim: int - - Then you can pass model.dim=32 to change values in LMTransformerArgsgs - or just name=tictac for top level attributes. - - The behavior here is as follows: - 1. We instantiate EvalArgs with its default values - 2. We override those default values with the ones in the provided config file - 3. We override the result with the additional arguments provided through command line - - For example, if the config is the following - - model: - dim: 128 - n_layers: 4 - - and you call eval.py with eval.py model.dim=64 - - Then the final TrainArgs will have - - model: - dim: 64 - n_layers: 4 - - Plus all the default values in EvalArgs dataclass. - """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.structured(EvalArgs()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_object(cfg) - launch_eval(cfg) + eval_args = parse_args(EvalArgs) + launch_eval(eval_args) if __name__ == "__main__": diff --git a/apps/main/generate.py b/bytelatent/generate.py similarity index 91% rename from apps/main/generate.py rename to bytelatent/generate.py index a3a8627..eb79d81 100644 --- a/apps/main/generate.py +++ b/bytelatent/generate.py @@ -1,20 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import List, Optional import torch -from lingua.args import dataclass_from_dict -from lingua.tokenizers.abstract_tokenizer import Tokenizer -from lingua.tokenizers.build_tokenizer import build_tokenizer from omegaconf import OmegaConf from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import create_block_mask from tqdm import tqdm +from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs from bytelatent.base_transformer import ( Attention, causal_mask, @@ -23,7 +19,10 @@ from bytelatent.base_transformer import ( lengths_to_start_ids, ) from bytelatent.checkpoint import CONSOLIDATE_NAME -from bytelatent.transformer import LMTransformer, LMTransformerArgs +from bytelatent.data.file_util import get_fs +from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer +from bytelatent.transformer import LMTransformer def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: @@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None): return next_token.view(shape[:-1]) -def pack_prompts(prompts: List[int]): +def pack_prompts(prompts: list[int]): res = [] lengths = [] for i, p in enumerate(prompts): @@ -120,22 +119,6 @@ class KVCache(nn.Module): return self.k_cache, self.v_cache -@dataclass -class PackedCausalTransformerGeneratorArgs: - temperature: float = 0.0 - top_p: Optional[float] = None - top_k: Optional[float] = None - max_gen_len: int = 512 # Maximum number of tokens to generate - max_tokens: int = 1024 # Maximum number of tokens that can go through the model - max_prompt_len: Optional[int] = None - until: List[str] = field(default_factory=list) - compile_prefilling: bool = False - reduce_generation_overhead: bool = False - show_progress: bool = False - dtype: Optional[str] = "bf16" - device: Optional[str] = "cuda" - - class PackedCausalTransformerGenerator: def __init__( self, @@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator: def load_consolidated_model_and_tokenizer( consolidated_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ): - ckpt_path = Path(consolidated_path) - config = ckpt_path / "params.json" - config = OmegaConf.load(config) + train_args_path = os.path.join(consolidated_path, "params.json") + fs = get_fs(train_args_path) + with fs.open(train_args_path) as f: + train_args = TrainArgs.model_validate_json(f.read()) + + if train_args.train_entropy_model: + model_args = train_args.entropy_model + model = LMTransformer(model_args) + else: + model_args = train_args.model + model = ByteLatentTransformer(model_args) param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ - config.distributed.model_dtype + train_args.distributed.model_dtype ] - model_args = dataclass_from_dict(model_args_cls, config.model, strict=False) - tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path) - model = model_cls(model_args) - st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True) + tokenizer = train_args.data.tokenizer_args.build() + st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): param.data = param.data.to(dtype=param_dtype) - return model, tokenizer, config + return model, tokenizer, train_args def main(): diff --git 
a/bytelatent/train.py b/bytelatent/train.py index 1d0fa40..6b20ecd 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -10,7 +10,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pathlib import Path from timeit import default_timer as timer -from typing import Any, Dict, Type, TypeVar +from typing import Any, TypeVar import torch import torch.distributed @@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs +from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint -from bytelatent.data.data_types import DataLoaderState +from bytelatent.data.iterators.multiprocess_iterator import ( + MultiprocessIterator, + MultiprocessIteratorState, +) +from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, @@ -39,6 +43,7 @@ from bytelatent.distributed import ( setup_env, setup_torch_distributed, ) +from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer @@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"): return dict(items) -def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T: - """ - Converts a dictionary to a dataclass instance, recursively for nested structures. - """ - base = OmegaConf.structured(cls()) - OmegaConf.set_struct(base, strict) - override = OmegaConf.create(data) - return OmegaConf.to_object(OmegaConf.merge(base, override)) +def get_iterator_state_name(iterator_state): + if isinstance(iterator_state, MultiprocessIteratorState): + return "multiprocess" + elif isinstance(iterator_state, PackingIteratorState): + return "packing" + else: + raise ValueError(f"Unsupported iterator to get name from: {iterator_state}") +# TODO: Make this pydantic based instead of data class based +# TODO: Generalize this to any iterator state @dataclass class TrainState(Stateful): step: int # Nb of steps taken by the optimizer acc_step: int # Nb of accumulation steps done since last optimizer step scheduler: lr_scheduler.LambdaLR - data_loader_state: DataLoaderState + data_loader_state: MultiprocessIteratorState | PackingIteratorState scale: float = 1.0 + data_loader_class: str | None = None - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "step": self.step, "acc_step": self.acc_step, - "data_loader_state": self.data_loader_state.dict(), + "data_loader_state": self.data_loader_state.model_dump(), + "data_loader_class": get_iterator_state_name(self.data_loader_state), "scheduler": self.scheduler.state_dict(), } def load_state_dict(self, state_dict): self.step = state_dict["step"] self.acc_step = state_dict["acc_step"] - self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"]) + self.data_loader_class = state_dict["data_loader_class"] + if self.data_loader_class == "multiprocess": + self.data_loader_state = MultiprocessIteratorState( + **state_dict["data_loader_state"] + ) + elif self.data_loader_class == "packing": + self.data_loader_state = PackingIteratorState( + **state_dict["data_loader_state"] + ) + else: + raise ValueError(f"invalid data loader class: {self.data_loader_class}") 
self.scheduler.load_state_dict(state_dict["scheduler"]) @@ -345,7 +363,10 @@ def train(args: TrainArgs): nwords_since_last_log = 0 time_last_log = timer() gc.collect() - while train_state.step < args.steps: + saved = False + while train_state.step < args.steps and ( + args.max_steps is None or train_state.step < args.max_steps + ): # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 train_state.acc_step += 1 train_state.acc_step = train_state.acc_step % args.grad_acc_steps @@ -552,7 +573,6 @@ def train(args: TrainArgs): f" pow: {gpu_mem_stats.power_draw/1000} W" ) - saved = False if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): @@ -567,18 +587,14 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval - - eval_args = dataclass_from_dict(EvalArgs, args.eval) + eval_args = args.eval eval_args.global_step = train_state.step eval_args.ckpt_dir = str(checkpoint.existing_saves[-1]) - eval_args.dump_dir = str( - os.path.join( - args.dump_dir, - "evals", - EVAL_FOLDER_NAME.format(train_state.step), - ) + eval_args.dump_dir = os.path.join( + args.dump_dir, + "evals", + EVAL_FOLDER_NAME.format(train_state.step), ) eval_args.metric_log_dir = args.dump_dir if args.async_eval_gpus is None: @@ -619,6 +635,9 @@ def train(args: TrainArgs): args, device_mesh=world_mesh, ) + if isinstance(data_loader, MultiprocessIterator): + logger.info("Closing MP iterator before exiting") + data_loader.shutdown() gc.collect() @@ -661,15 +680,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(TrainArgs().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - train_args = TrainArgs.model_validate(cfg) + train_args = parse_args(TrainArgs) if train_args.debug_dynamo: import torch._dynamo
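
For reference, a minimal sketch of how the pieces introduced in these two patches fit together when training the entropy model. It mirrors bytelatent/configs/entropy_model.yaml and the TrainArgs fields shown above (model: null plus train_entropy_model: true selects the LMTransformer path; add_patches: false with the bytes tokenizer drives the byte-packing iterator). The paths and hyperparameter values are illustrative placeholders, and TrainArgs.model_validate is the same validation step parse_args performs after merging the YAML file and CLI overrides.

# Sketch only: programmatic equivalent of
#   python -m bytelatent.train config=bytelatent/configs/entropy_model.yaml
# Field names come from the diffs above; concrete values are placeholders.
from bytelatent.args import TrainArgs

cfg = {
    "name": "entropy_debug",
    "dump_dir": "/tmp/entropy_debug",        # placeholder
    "steps": 100_000,
    "max_steps": 100,                        # debugging knob added in patch 2
    "train_entropy_model": True,             # train LMTransformer instead of the BLT model
    "model": None,
    "entropy_model": {
        "dim": 768, "n_layers": 14, "n_heads": 12,
        "max_seqlen": 8192, "vocab_size": 260,
        "ffn_dim_multiplier": 1.0, "sliding_window": 512,
        "attn_bias_type": "local_block_causal", "attn_impl": "xformers",
    },
    "data": {
        "root_dir": "/path/to/data",         # placeholder
        "preprocess_dir": "/path/to/prep",   # placeholder
        "sources": {"dclm_baseline_1.0": 1.0},
        "batch_size": 2,
        # seq_len and max_encoder_seq_length coincide: 1 patch = 1 byte here
        "seq_len": 8192,
        "max_encoder_seq_length": 8192,
        "add_patches": False,                # byte-packing path in PackingIterator
        "patcher_args": {"patching_mode": "byte"},
        "tokenizer_args": {"name": "bytes"}, # pad_id falls back to 0
    },
}

train_args = TrainArgs.model_validate(cfg)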