From 7622d28b749b632498f5c9ecddd46e0e099ab1df Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <par@meta.com>
Date: Mon, 27 Jan 2025 09:46:44 -0800
Subject: [PATCH] Initial codes and scripts for training entropy model (#34)

Summary:

Test Plan:
---
 .gitignore                                    |  1 +
 bytelatent/args.py                            | 13 ++-
 bytelatent/configs/debug.yaml                 |  3 +-
 bytelatent/configs/entropy_model.yaml         | 82 +++++++++++++++++++
 bytelatent/data/data_types.py                 |  2 +-
 bytelatent/data/iterators/packing_iterator.py | 42 ++++++++++
 .../data/iterators/sequence_iterator.py       | 30 +++++--
 bytelatent/data/patcher.py                    | 10 ++-
 bytelatent/model/blt.py                       |  5 +-
 bytelatent/test_blt.py                        |  3 +-
 bytelatent/train.py                           | 52 +++++++++---
 11 files changed, 209 insertions(+), 34 deletions(-)
 create mode 100644 bytelatent/configs/entropy_model.yaml

diff --git a/.gitignore b/.gitignore
index 6c664b8..d1d7c2a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,3 +166,4 @@ figures/
 .vscode/
 .DS_Store
 internal/
+jobs_parallel-copy/
diff --git a/bytelatent/args.py b/bytelatent/args.py
index a332c89..56de22d 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel):
     max_encoder_seq_length: int = 12288
     enable_byte_ngrams: bool = False
 
+    add_patches: bool = True
+
     tokenizer_args: TokenizerArgs = TokenizerArgs()
     patcher_args: PatcherArgs = PatcherArgs()
 
@@ -120,6 +122,7 @@ class DataloaderArgs(BaseModel):
                 looping_iterator,
                 patcher_args=self.patcher_args,
                 tokenizer_args=self.tokenizer_args,
+                add_patches=self.add_patches,
             )
             sequence_iterator = SequenceIterator(
                 preprocess_iterator,
@@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel):
             source_to_iterator=source_to_sequence_iterators,
         )
         tokenizer = self.tokenizer_args.build()
+        if self.tokenizer_args.name == "bytes":
+            # TODO: Check this with Artidoro
+            pad_id = 0
+        else:
+            pad_id = tokenizer.boe_id
         packing_args = PackingArgs(
             batch_size=self.batch_size,
             seq_len=self.seq_len,
-            pad_id=tokenizer.boe_id,
+            pad_id=pad_id,
             max_length=self.max_encoder_seq_length,
             pad_to_max_length=self.pad_to_max_length,
             enable_byte_ngrams=self.enable_byte_ngrams,
+            tokenizer_name=self.tokenizer_args.name,
         )
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
@@ -180,7 +189,7 @@ class TrainArgs(BaseModel):
 
     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
-    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
+    model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs()
     # This is only needed for training the entropy model
     entropy_model: LMTransformerArgs | None = None
     # Instead of training main model, train entropy model
diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml
index 4ae4459..1098ff5 100644
--- a/bytelatent/configs/debug.yaml
+++ b/bytelatent/configs/debug.yaml
@@ -26,10 +26,9 @@ model:
   vocab_size: 260
   dim_token: 256
   patch_size: 6
-  tokenization_mode: "bytes"
   patching_mode: "space"
   tie_local_encoder_decoder_logits: false
-  data_loader_patching: true
+  patch_in_forward: false
   max_encoder_seq_length: 12288
   pad_to_max_length: true
   patching_threshold: 3.1439168453216553
diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml
new file mode 100644
index 0000000..51b65d4
--- /dev/null
+++ b/bytelatent/configs/entropy_model.yaml
@@ -0,0 +1,82 @@
+# Template config, need to change dump_dir, data.root_dir and tokenizer.path
+# Evals can be activated by uncommenting its config
+# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
+
+dump_dir: /tmp/
+name: "debug"
+steps: 100_000
+probe_freq: null
+seed: 777
+optim:
+  lr: 4e-04
+  warmup: 500
+  lr_min_ratio: 0.1
+  clip: 10.0
+
+distributed:
+  fsdp_type: full_shard
+  model_dtype: bf16
+  matmul_allow_tf32: false
+  selective_activation_checkpointing: false
+  tp_size: 1
+
+train_entropy_model: true
+model: null
+entropy_model:
+  dim: 768
+  n_layers: 14
+  n_heads: 12
+  max_seqlen: 8192
+  # vocab_size: -1
+  vocab_size: 260
+  ffn_dim_multiplier: 1.0
+  sliding_window: 512
+  attn_bias_type: "local_block_causal"
+  attn_impl: "xformers"
+
+data:
+  s3_profile: blt
+  root_dir: ???
+  sources:
+    dclm_baseline_1.0: 1.0
+  batch_size: 2
+  prefetch_size: 64
+  # seqlen is in terms of patches and
+  # max_encoder_seq_length is in terms of bytes.
+  # For entropy model, these are the same since 1 patch=1 byte
+  seq_len: 8192
+  max_encoder_seq_length: 8192
+  load_async: true
+  preprocess_dir: ???
+  # We don't need patches for this model
+  add_patches: false
+  patcher_args:
+    # This doesn't matter since byte entropy model doesn't use patching,
+    # so pick the most efficient, so static
+    patching_mode: byte
+  tokenizer_args:
+    name: bytes
+
+profiling:
+  run: false
+
+checkpoint:
+  dump:
+    every: 500
+    keep: 3
+  eval:
+    every: 1000
+    keep: -1
+
+logging:
+  freq: 10
+
+eval_on_gpus: 8
+eval:
+  dataset_dir: ???
+  tasks: ???
+  generator:
+    max_tokens: 65536
+    dtype: bf16
+
+  mp_size: 1
diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py
index 7e142e4..aa2daa9 100644
--- a/bytelatent/data/data_types.py
+++ b/bytelatent/data/data_types.py
@@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
 class BltSequence(BaseModel):
     tokens: list[int]
     mask: list[bool]
-    patch_lengths: list[int]
+    patch_lengths: list[int] | None
 
 
 @dataclass
diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index 361fc03..fa29149 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -17,6 +17,7 @@ class PackingArgs(BaseModel):
     max_length: int | None
     pad_to_max_length: bool
     enable_byte_ngrams: bool
+    tokenizer_name: str
 
 
 class PackingIteratorState(BaseModel, IteratorState):
@@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         )
 
     def create_iter(self):
+        if self.packing_args.tokenizer_name == "bytes":
+            return self._create_iter_from_bytes()
+        else:
+            return self._create_iter_from_patch_lengths()
+
+    def _create_iter_from_bytes(self):
+        sequence_iter = self.sequence_iterator.create_iter()
+        batch_size = self.packing_args.batch_size
+        pad_id = self.packing_args.pad_id
+        seq_len = self.packing_args.seq_len
+        while True:
+            tokens: list[list[int]] = []
+            masks: list[list[bool]] = []
+
+            for _ in range(self.packing_args.batch_size):
+                sequence = next(sequence_iter)
+                _tokens = sequence.tokens
+                _mask = sequence.mask
+                assert (
+                    sequence.patch_lengths is None
+                ), "patch_lengths should not be used in byte packing"
+                tokens.append(_tokens)
+                masks.append(_mask)
+
+            x = np.full((batch_size, seq_len), fill_value=pad_id)
+            y = np.full((batch_size, seq_len), fill_value=pad_id)
+
+            for i, tok_seq in enumerate(tokens):
+                x[i, : len(tok_seq)] = tok_seq
+                y[i, : len(tok_seq) - 1] = tok_seq[1:]
+            batch = Batch(x=x, y=y)
+            assert (
+                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
+            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
+            yield batch
+
+    def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
@@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 _tokens = sequence.tokens
                 _mask = sequence.mask
                 _patch_lengths = sequence.patch_lengths
+                assert (
+                    _patch_lengths is not None
+                ), "patch lengths are required for packing based on patches."
+                # Reminder: seq_len is in terms of patches
                 assert len(sequence.patch_lengths) == self.packing_args.seq_len
                 last_patch_length = 0
                 if _patch_lengths[0] > 1:
diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py
index 14e3747..d90ea31 100644
--- a/bytelatent/data/iterators/sequence_iterator.py
+++ b/bytelatent/data/iterators/sequence_iterator.py
@@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator):
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
-            assert example.patch_lengths is not None
+            if self.preprocess_iterator.add_patches:
+                assert example.patch_lengths is not None
+                assert len(example.tokens) == sum(example.patch_lengths)
+            else:
+                assert example.patch_lengths is None
             assert len(example.tokens) != 0
             assert len(example.mask) != 0
             assert len(example.tokens) == len(example.mask)
-            assert len(example.tokens) == sum(example.patch_lengths)
 
             tokens.extend(example.tokens)
             mask.extend(example.mask)
-            patch_lengths.extend(example.patch_lengths)
+            if self.preprocess_iterator.add_patches:
+                patch_lengths.extend(example.patch_lengths)
+            else:
+                # This lets the rest of the code work as expected and just yield byte seqs
+                patch_lengths.extend([1] * len(example.tokens))
 
             while len(patch_lengths) >= n_buffer_patches:
                 if first:
@@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator):
                         == len(seq_mask[idx])
                     ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
                     assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
-                    yield BltSequence(
-                        tokens=seq_tokens[idx],
-                        mask=seq_mask[idx],
-                        patch_lengths=seq_patch_lengths[idx],
-                    )
+                    if self.preprocess_iterator.add_patches:
+                        yield BltSequence(
+                            tokens=seq_tokens[idx],
+                            mask=seq_mask[idx],
+                            patch_lengths=seq_patch_lengths[idx],
+                        )
+                    else:
+                        yield BltSequence(
+                            tokens=seq_tokens[idx],
+                            mask=seq_mask[idx],
+                            patch_lengths=None,
+                        )
diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index afcfa2e..44ff5e9 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum):
     bpe = "bpe"
     bpe_patcher = "bpe_patcher"
     space = "space"
+    static = "static"
+    byte = "byte"
 
 
 class PatcherArgs(BaseModel):
@@ -34,7 +36,6 @@ class PatcherArgs(BaseModel):
     max_patch_length: int | None = None
     patch_size: float = 4.5
     patching_batch_size: int = 1
-    data_loader_patching: bool = False
     device: str = "cuda"
     monotonicity: bool = False
     log_time: bool = False
@@ -486,7 +487,6 @@ class Patcher:
         self.max_patch_length = patcher_args.max_patch_length
         self.patch_size = patcher_args.patch_size
         self.patching_batch_size = patcher_args.patching_batch_size
-        self.data_loader_patching = patcher_args.data_loader_patching
         self.device = patcher_args.device
         self.monotonicity = patcher_args.monotonicity
         self.log_time = patcher_args.log_time
@@ -528,7 +528,7 @@ class Patcher:
         seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
         scores = None
         # STATIC
-        if self.patching_mode is None:
+        if self.patching_mode == PatchingModeEnum.static:
             patch_lengths = torch.zeros(
                 (bs, math.ceil(seq_len_next_tok / self.patch_size)),
                 dtype=tokens.dtype,
@@ -536,6 +536,10 @@ class Patcher:
             ).fill_(self.patch_size)
             if seq_len_next_tok % self.patch_size != 0:
                 patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
+        elif self.patching_mode == PatchingModeEnum.byte:
+            patch_lengths = torch.ones(
+                (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device
+            )
         # ENTROPY
         elif self.patching_mode == PatchingModeEnum.entropy:
             if self.log_time:
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index 843ad34..a62be23 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_heads: int = 8
     # TODO: What is the purpose of this parameter?
     weight_tying: bool = False
+    patch_in_forward: bool = False
 
     # Architecture and dimensions
     dim_token: int = 256
@@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_layers_local_encoder: int = 8
 
     # Tokenization and patching
-    tokenization_mode: str = "bpe"
     patch_size: float | None = None
     patching_mode: str | None = None
     patching_threshold: float | None = None
@@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     monotonicity: bool = False
     patching_batch_size: int = 1
     patching_device: str = "cuda"
-    data_loader_patching: bool = False
     max_patch_length: int | None = None
 
     # Encoder/Decoder configuration
@@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module):
             self.output.weight = self.tok_embeddings.weight
 
         # Patcher module
-        if not args.data_loader_patching:
+        if args.patch_in_forward:
             self.patcher = Patcher(
                 PatcherArgs(
                     patch_size=args.patch_size,
diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py
index 36a9882..eb94df3 100644
--- a/bytelatent/test_blt.py
+++ b/bytelatent/test_blt.py
@@ -68,10 +68,9 @@ def create_args(cross_attention=False):
         # Additional args from command line
         dim_token=256,
         patch_size=6,
-        tokenization_mode="bytes",
         patching_mode="space",
         tie_local_encoder_decoder_logits=False,
-        data_loader_patching=True,
+        patch_in_forward=False,
         max_encoder_seq_length=12288,
         pad_to_max_length=True,
         encoder_lm_loss=False,
diff --git a/bytelatent/train.py b/bytelatent/train.py
index 80bd393..1d0fa40 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD
 from bytelatent.profiling import maybe_run_profiler
 from bytelatent.stool import StoolArgs, launch_job
 from bytelatent.transformer import (
+    LMTransformer,
     build_fsdp_grouping_plan,
     get_no_recompute_ops,
     get_num_flop_per_token,
@@ -103,10 +104,15 @@ class TrainState(Stateful):
 
 
 def validate_train_args(args: TrainArgs, output_size: int):
-    if args.model.vocab_size < 0:
+    assert args.model is not None or args.entropy_model is not None
+    if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
 
+    if args.entropy_model is not None:
+        logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
+        args.entropy_model.vocab_size = output_size
+
     assert args.dump_dir, "Dump dir not set"
 
     if args.checkpoint.path is None:
@@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int):
                 and args.distributed.dp_replicate == get_world_size()
             )
 
-    args.model.max_seqlen = args.data.seq_len
+    if args.model is not None:
+        args.model.max_seqlen = args.data.seq_len
+    if args.entropy_model is not None:
+        args.entropy_model.max_seqlen = args.data.seq_len
 
     if args.distributed.tp_size == 1:
         logger.warning(
@@ -237,7 +246,14 @@ def train(args: TrainArgs):
 
         # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
         with torch.device("meta"):
-            model = ByteLatentTransformer(args.model)
+            if args.train_entropy_model:
+                assert args.entropy_model is not None
+                model = LMTransformer(args.entropy_model)
+                model_args = args.entropy_model
+            else:
+                assert args.model is not None
+                model = ByteLatentTransformer(args.model)
+                model_args = args.model
         logger.info("Model is built !")
 
         model_param_count = get_num_params(model)
@@ -247,7 +263,7 @@ def train(args: TrainArgs):
             world_mesh,
             args.model,
             args.distributed,
-            fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
+            fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
             tp_parallelize=tp_parallelize,
             no_recompute_ops=get_no_recompute_ops(),
         )
@@ -267,7 +283,7 @@ def train(args: TrainArgs):
             model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
         else:
             with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
-                torch.manual_seed(args.model.seed)
+                torch.manual_seed(model_args.seed)
                 model.init_weights()
         check_model_value_range(model, range=10.0, std=1.0)
 
@@ -342,10 +358,17 @@ def train(args: TrainArgs):
                 batch.x,
             ).cuda()
             batch_y = torch.from_numpy(batch.y).cuda()
-            batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
+            if batch.patch_lengths is None:
+                batch_patch_lengths = None
+            else:
+                batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
             mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
 
-            if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None:
+            if (
+                not args.train_entropy_model
+                and args.model.encoder_enable_byte_ngrams
+                and batch.ngram_ids is None
+            ):
                 raise ValueError(
                     "Cannot enable byte ngrams and have batch.ngram_ids be None"
                 )
@@ -408,9 +431,12 @@ def train(args: TrainArgs):
                     next(probe_mod.parameters()).grad is None
                 ), "Probe model shouldn't have grads at this point"
 
-            pred = model(
-                batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
-            )
+            if args.train_entropy_model:
+                pred = model(batch_x)
+            else:
+                pred = model(
+                    batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
+                )
 
             loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
 
@@ -474,9 +500,9 @@ def train(args: TrainArgs):
                 # Use xformer's analyze profile trace to get actual measurement
                 FLOPS = (
                     get_num_flop_per_token(
-                        model_param_count - args.model.vocab_size * args.model.dim,
-                        args.model.n_layers,
-                        args.model.dim,
+                        model_param_count - model_args.vocab_size * model_args.dim,
+                        model_args.n_layers,
+                        model_args.dim,
                         args.data.seq_len,
                     )
                     * wps