commit bcc039bb75aec70385fd5c148497d4ae86b526b5
Author: Pedro Rodriguez <par@meta.com>
Date:   Thu Dec 12 15:32:30 2024 -0800

    Initial commit

diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml
new file mode 100644
index 0000000..433932f
--- /dev/null
+++ b/.github/workflows/black.yml
@@ -0,0 +1,12 @@
+name: Lint with Black
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: psf/black@stable
+        with:
+          version: "24.8.0"
diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml
new file mode 100644
index 0000000..16f1abf
--- /dev/null
+++ b/.github/workflows/isort.yml
@@ -0,0 +1,10 @@
+name: Lint with isort
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: isort/isort-action@master
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..56891a9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,168 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+*.out
+
+figures/
+.vscode/
+.DS_Store
+
diff --git a/.prettierrc b/.prettierrc
new file mode 100644
index 0000000..ae6f541
--- /dev/null
+++ b/.prettierrc
@@ -0,0 +1,8 @@
+{
+  "overrides": [
+    {
+      "files": "*.yaml",
+      "options": { "tabWidth": 2 }
+    }
+  ]
+}
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000..3232ed6
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,80 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when there is a
+reasonable belief that an individual's behavior may have a negative impact on
+the project or its community.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <opensource-conduct@meta.com>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..88d3ab0
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,36 @@
+# Contributing to BLT
+
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Meta's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+
+By contributing to BLT, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..bc559a9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,28 @@
+BSD 3-Clause License
+
+Copyright 2024 Meta
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this list
+of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice, this
+list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors may
+be used to endorse or promote products derived from this software without specific
+prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
+EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
+OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
+SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
+BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+DAMAGE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..689c8e4
--- /dev/null
+++ b/README.md
@@ -0,0 +1,117 @@
+# Byte Latent Transformer
+
+This repository contains code for our paper: "Byte Latent Transformer: Patches Scale Better Than Tokens"
+
+- [Paper Link](https://dl.fbaipublicfiles.com/blt/BLT__Patches_Scale_Better_Than_Tokens.pdf)
+
+## Abstract
+
+We introduce the Byte Latent Transformer (BLT), a new byte-level LLM architecture that,
+for the first time, matches tokenization-based LLM performance at scale, with significant improvements
+in inference efficiency and robustness. BLT encodes bytes into dynamically sized patches, which serve
+as the primary units of computation. Patches are segmented dynamically based on the entropy of the
+next byte, allocating more compute and model capacity where there is more data complexity. The BLT
+architecture includes new attention mechanisms to maximize the information flow between byte and
+patch hidden representations and a new type of byte-sequence memory. We present the first scaling
+study of byte-level models up to 8B parameters and 8T training bytes, showing for the first time
+that we can train a model end-to-end at scale from bytes with no tokenization or other preprocessing.
+Scaling trends reveal training and inference efficiency benefits from dynamically selecting very long
+patches on average, along with qualitative improvements in reasoning and long-tail generalization
+from modeling byte-sequences.
+
+![BLT Architecture Diagram](blt-figure.jpg)
+
+## Development Status
+
+We are actively updating the BLT code to make it easier to reproduce our results.
+Please file an issue and/or be patient while we make more of our code public!
+
+## Quick start
+
+The following commands create an environment for BLT (either locally or as a SLURM job).
+The env creation should take around 5 minutes without counting downloads.
+
+```bash
+git clone https://github.com/facebookresearch/blt
+cd blt
+
+bash setup/create_env.sh
+# or if you have access to a SLURM cluster
+sbatch setup/create_env.sh
+```
+
+Once that is done, you can activate the environment:
+
+```bash
+conda activate blt_<date>
+```
+
+Use the provided script to download and prepare data from Hugging Face (choose among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
+The command below downloads `fineweb_edu` and prepares it for training in the `./data` directory; `<MEMORY>` specifies the amount of memory allocated to `terashuf` (the tool used to shuffle samples). By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` to the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.
+
+```bash
+python setup/download_prepare_hf_data.py fineweb_edu <MEMORY> --data_dir ./data --seed 42 --nchunks <NCHUNKS>
+```
+
+To download the tokenizer (here, llama3), use the following script:
+
+```bash
+python setup/download_tokenizer.py llama3 <SAVE_PATH> --api_key <HUGGINGFACE_TOKEN>
+```
+
+Now launch a debug job to check if everything works. **The provided configurations are templates; you need to adapt them (change `dump_dir`, `data.root_dir`, `data.tokenizer.path`, etc.) for them to work.**
+
+```bash
+# stool stands for SLURM tool !
+python -m bytelatent.stool script=bytelatent.train config=apps/bytelatent/configs/debug.yaml nodes=1 partition=<partition>
+# if you want to launch locally you can use torchrun
+torchrun --nproc-per-node 8 -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
+# or you can also launch on 1 GPU
+python -m bytelatent.train  config=apps/bytelatent/configs/debug.yaml
+```
+
+When using `stool`, if a job crashes, it can be relaunched using sbatch:
+
+```bash
+sbatch path/to/dump_dir/submit.slurm
+```
+
+## Linting
+
+To lint, run the following command:
+
+```
+bash dev/lint.sh
+```
+
+## Citation
+
+BLT is partially based on Meta Lingua, so consider citing Lingua in addition to our BLT paper if you re-use our work.
+
+BLT Paper Citation (will be updated to arXiv soon)
+
+```
+@article{meta_blt,
+  author = {Artidoro Pagnoni, Ram Pasunuru, Pedro Rodriguez, John Nguyen, Benjamin Muller, Margaret Li, Chunting Zhou, Lili Yu, Jason Weston, Luke Zettlemoyer, Gargi Ghosh, Mike Lewis, Ari Holtzman†, Srinivasan Iyer},
+  title = {Byte Latent Transformer: Patches Scale Better Than Tokens},
+  url = {https://github.com/facebookresearch/blt},
+  year = {2024}
+}
+```
+
+Lingua Code
+
+```
+@misc{meta_lingua,
+  author = {Mathurin Videau, Badr Youbi Idrissi, Daniel Haziza, Luca Wehrstedt, Jade Copet, Olivier Teytaud, David Lopez-Paz},
+  title = {{Meta Lingua}: A minimal {PyTorch LLM} training library},
+  url = {https://github.com/facebookresearch/lingua},
+  year = {2024}
+}
+```
+
+## License
+
+The BLT code is partially based on Meta Lingua.
+
+Meta Lingua is licensed under the BSD-3-Clause license. Refer to the LICENSE file in the top-level directory.
diff --git a/apps/__init__.py b/apps/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/apps/main/__init__.py b/apps/main/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/apps/main/configs/eval.yaml b/apps/main/configs/eval.yaml
new file mode 100644
index 0000000..4b52ba0
--- /dev/null
+++ b/apps/main/configs/eval.yaml
@@ -0,0 +1,35 @@
+name: "debug_evals"
+# ckpt_dir: !!CHANGETHIS!!
+# dump_dir: !!CHANGETHIS!!
+generator:
+  max_tokens: 8192
+  dtype: bf16
+  temperature: 1.0
+  top_p: 0.95
+harness:
+  tasks:
+    - hellaswag
+    - task: boolq
+      dataset_kwargs:
+        trust_remote_code: true
+    - task: nq_open
+      num_fewshot: 5
+    - piqa
+    - task: social_iqa
+      dataset_kwargs:
+        trust_remote_code: true
+    - triviaqa
+    - winogrande
+    - openbookqa
+    - arc_easy
+    - arc_challenge
+    - race
+    - commonsense_qa
+    # - coqa
+    - copa
+    - gsm8k
+    - bbh
+    - mmlu
+    - mmlu_pro
+validation:
+  max_steps: 1000
diff --git a/apps/main/configs/llama_1B.yaml b/apps/main/configs/llama_1B.yaml
new file mode 100644
index 0000000..786d848
--- /dev/null
+++ b/apps/main/configs/llama_1B.yaml
@@ -0,0 +1,87 @@
+# dump_dir: !!!CHANGE_THIS!!!
+name: large_lm
+steps: 60_000
+probe_freq: null
+seed: 777
+
+optim:
+  lr: 3e-3
+  weight_decay: 0.033
+  warmup: 5000
+  lr_min_ratio: 0.000001
+  clip: 1.0
+
+distributed:
+  fsdp_type: full_shard
+  compile: true
+  model_dtype: bf16
+  matmul_allow_tf32: false
+  selective_activation_checkpointing: false
+  tp_size: 1
+
+model:
+  dim: 2048
+  n_layers: 25
+  n_heads: 16
+
+data:
+  root_dir: data/shuffled
+  sources:
+    dclm_baseline_1.0: 100.0
+  batch_size: 4
+  prefetch_size: 1024
+  seq_len: 4096
+  n_views: 2
+  load_async: true
+  add_bos: true
+  add_eos: true
+  tokenizer:
+    name: tiktoken
+    path: tokenizers/cl_toplang_128k.tiktoken
+
+profiling:
+  run: true
+  mem_warmup: 0
+  mem_steps: 4
+  profile_warmup: 100
+  profile_steps: 4
+
+checkpoint:
+  dump:
+    every: 2500
+    keep: 3
+  eval:
+    every: 5000
+    keep: -1
+
+logging:
+  freq: 1
+
+async_eval_gpus: 8
+eval:
+  harness:
+    tasks:
+      - hellaswag
+      - task: boolq
+        dataset_kwargs:
+          trust_remote_code: true
+      - piqa
+      - task: social_iqa
+        dataset_kwargs:
+          trust_remote_code: true
+      - winogrande
+      - openbookqa
+      - arc_easy
+      - arc_challenge
+      - race
+      - commonsense_qa
+      - copa
+      # - coqa
+      # - task: nq_open
+      #   num_fewshot: 5
+      # - triviaqa
+  validation:
+    max_steps: 1000
+  generator:
+    max_tokens: 16384
+    dtype: bf16
diff --git a/apps/main/configs/llama_7B.yaml b/apps/main/configs/llama_7B.yaml
new file mode 100644
index 0000000..4461dd9
--- /dev/null
+++ b/apps/main/configs/llama_7B.yaml
@@ -0,0 +1,95 @@
+#python -m lingua.stool config=apps/main/configs/llama2_7B.yaml nodes=32 account=fair_amaia_cw_codegen qos=lowest
+# dump_dir: !!!CHANGE_THIS!!!
+name: "7b_baseline"
+steps: 100_000
+grad_acc_steps: 1
+probe_freq: 100
+
+seed: 777
+optim:
+  lr: 1.0e-3
+  weight_decay: 0.1
+  warmup: 2000
+  lr_min_ratio: 0.000001
+  clip: 1.0
+
+distributed:
+  fsdp_type: full_shard
+  compile: true
+  model_dtype: bf16
+  matmul_allow_tf32: false
+  selective_activation_checkpointing: false
+  tp_size: 1
+
+model:
+  dim: 4096
+  n_layers: 32
+  n_heads: 32
+  rope_theta: 100_000
+  ffn_dim_multiplier: 1.0
+  multiple_of: 256
+
+data:
+  root_dir: data/shuffled
+  sources:
+    dclm_baseline_1.0: 1.0
+  batch_size: 2
+  prefetch_size: 1024
+  seq_len: 4096
+  n_views: 2
+  load_async: true
+  tokenizer:
+    name: tiktoken
+    path: tokenizers/cl_toplang_128k.tiktoken
+
+profiling:
+  run: true
+  mem_warmup: 0
+  mem_steps: 4
+  profile_warmup: 100
+  profile_steps: 4
+
+checkpoint:
+  dump:
+    every: 10000
+    keep: -1
+  eval:
+    every: 1000
+    keep: 3
+
+logging:
+  freq: 1
+
+async_eval_gpus: 8
+eval:
+  dataset_dir: datasets/eval
+  harness:
+    tasks:
+      - hellaswag
+      - task: boolq
+        dataset_kwargs:
+          trust_remote_code: true
+      - piqa
+      - task: social_iqa
+        dataset_kwargs:
+          trust_remote_code: true
+      - winogrande
+      - openbookqa
+      - arc_easy
+      - arc_challenge
+      - race
+      - commonsense_qa
+      # - coqa
+      - copa
+      - mmlu
+      - mmlu_pro
+      # - task: nq_open
+      #   num_fewshot: 5
+      # - triviaqa
+      # - gsm8k
+      # - bbh
+  validation:
+    max_steps: 1000
+  generator:
+    max_tokens: 8192
+    dtype: bf16
diff --git a/apps/main/eval.py b/apps/main/eval.py
new file mode 100644
index 0000000..ed20f49
--- /dev/null
+++ b/apps/main/eval.py
@@ -0,0 +1,354 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import json
+import logging
+import os
+from collections import defaultdict
+from dataclasses import asdict, dataclass, field
+from datetime import datetime
+from pathlib import Path
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+from lingua.args import dump_config
+from lingua.data import init_choice_state, setup_sources
+from lm_eval import simple_evaluate
+from lm_eval.api.instance import Instance
+from lm_eval.api.model import LM
+from omegaconf import OmegaConf
+
+from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
+from bytelatent.distributed import (
+    DistributedArgs,
+    dist_mean_dict,
+    get_global_rank,
+    get_world_size,
+    setup_torch_distributed,
+)
+from bytelatent.transformer import LMTransformer, LMTransformerArgs
+
+from apps.main.generate import (
+    PackedCausalTransformerGenerator,
+    PackedCausalTransformerGeneratorArgs,
+    load_consolidated_model_and_tokenizer,
+)
+
+EVAL_FOLDER_NAME = "{:010d}"
+
+logger = logging.getLogger()
+
+
+@dataclass
+class LMHarnessArgs:
+    tasks: Optional[List[Any]] = None
+    num_fewshot: Optional[int] = None
+    device: Optional[str] = None
+    use_cache: Optional[str] = None
+    cache_requests: bool = False
+    rewrite_requests_cache: bool = False
+    delete_requests_cache: bool = False
+    limit: Optional[Union[int, float]] = None
+    bootstrap_iters: int = 100000
+    check_integrity: bool = False
+    write_out: bool = False
+    log_samples: bool = True
+    system_instruction: Optional[str] = None
+    apply_chat_template: Union[bool, str] = False
+    fewshot_as_multiturn: bool = False
+    gen_kwargs: Optional[str] = None
+    verbosity: str = "INFO"
+    predict_only: bool = False
+    random_seed: int = 0
+    numpy_random_seed: int = 1234
+    torch_random_seed: int = 1234
+    fewshot_random_seed: int = 1234
+
+
+@dataclass
+class ValidationArgs:
+    max_steps: Optional[int] = (
+        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
+    )
+    use_val_from_train_src: bool = True  # Use the validation set from training sources
+    root_dir: str = ""
+    sources: List[str] = field(default_factory=list)  # Other sources to eval on
+
+
+@dataclass
+class EvalArgs:
+    name: str = "evals"
+    dump_dir: Optional[str] = None
+    metric_log_dir: Optional[str] = None
+    ckpt_dir: str = ""
+    generator: PackedCausalTransformerGeneratorArgs = field(
+        default_factory=PackedCausalTransformerGeneratorArgs
+    )
+    harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
+    validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)
+
+    wandb: Optional[Any] = None
+
+    global_step: Optional[int] = None  # for in-training evaluation
+
+
+def all_dicts_same(dict_list):
+    if not dict_list:  # Check if the list is empty
+        return True
+
+    # Compare each dictionary to the first one
+    first_dict = dict_list[0]
+    return all(d == first_dict for d in dict_list)
+
+
+class MockAccelerator:
+    def gather(self, tensor):
+        l = [torch.zeros_like(tensor) for _ in range(get_world_size())]
+        torch.distributed.all_gather(l, tensor)
+        return torch.stack(l)
+
+    def wait_for_everyone(self):
+        torch.distributed.barrier()
+
+
+# Light wrapper around generator for lm-eval harness
+class EvalHarnessLM(LM):
+    def __init__(self, generator):
+        super().__init__()
+        self.generator = generator
+        self.accelerator = MockAccelerator()
+        self._rank = get_global_rank()
+        self._world_size = get_world_size()
+        self.device = generator.device
+
+    def generate_until(self, requests: List[Instance]) -> List[str]:
+        prompts, gen_args = zip(*[req.args for req in requests])
+        assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
+        gen_args = gen_args[0]
+        temperature = gen_args.get("temperature", 0.0)
+        top_p = gen_args.get("top_p", None)
+        top_k = gen_args.get("top_k", None)
+        until = gen_args.get("until", [])
+
+        self.generator.temperature = temperature
+        self.generator.top_p = top_p
+        self.generator.top_k = top_k
+        self.generator.until = until
+        generations, _, _ = self.generator.generate(prompts)
+        filtered_gen = []
+        for g in generations:
+            for e in until:
+                g = g.replace(e, "")
+            filtered_gen.append(g)
+        return filtered_gen
+
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
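+        # For each (prompt, continuation) request, only the continuation is scored:
+        # the packed generator runs on prompt + continuation with max_gen_len
+        # temporarily lowered to 1, then log-likelihoods are summed (and greedy
+        # agreement checked) over the tokens that come after the prompt.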
+        prompts, continuations = zip(*[req.args for req in requests])
+        inputs = [req.args[0] + req.args[1] for req in requests]
+        max_gen_len = self.generator.max_gen_len
+        # We temporarily lower max gen len
+        self.generator.max_gen_len = 1
+        _, lls, greedy = self.generator.generate(inputs)
+        results = []
+        for p, ll, gr in zip(prompts, lls, greedy):
+            p_len = len(
+                self.generator.tokenizer.encode(p, add_bos=False, add_eos=False)
+            )
+            results.append((ll[p_len:].sum().item(), gr[p_len:].all().item()))
+
+        self.generator.max_gen_len = max_gen_len
+        return results
+
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+        prompts = [req.args[0] for req in requests]
+        max_gen_len = self.generator.max_gen_len
+        # We temporarily lower max gen len
+        self.generator.max_gen_len = 1
+        _, lls, _ = self.generator.generate(prompts)
+        results = []
+        for ll in lls:
+            results.append((ll.sum().item(),))
+        self.generator.max_gen_len = max_gen_len
+
+        return results
+
+
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
+    srcs = {}
+    for src in val_args.sources:
+        path = os.path.join(val_args.root_dir, src)
+        srcs[path] = 1.0
+    for src in train_cfg.data.sources:
+        path = os.path.join(train_cfg.data.root_dir, src)
+        srcs[path] = 1.0
+
+    multi_state = init_choice_state(
+        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
+    )
+    path_to_iter = setup_sources(multi_state)
+
+    max_gen_len = generator.max_gen_len
+    # We temporarily lower max gen len
+    generator.max_gen_len = 1
+
+    all_val_metrics = {}
+    for src in path_to_iter:
+        jsonl_iterator = path_to_iter[src]
+        texts = []
+        logger.info(f"Running validation on {src}...")
+        for step, (content, state) in enumerate(jsonl_iterator):
+            if state["current_iter"] > 0 or (
+                val_args.max_steps is not None and step >= val_args.max_steps
+            ):
+                break
+            content_key = "text" if ("text" in content) else "content"
+            texts.append(content[content_key])
+
+        _, loglikelihood, _ = generator.generate(texts)
+
+        metrics = defaultdict(list)
+        for i, ll in enumerate(loglikelihood):
+            tmp = ll.sum().item()
+            metrics["nll"].append(tmp)
+            metrics["nll_per_token"].append(tmp / len(ll))
+            metrics["nll_per_char"].append(tmp / len(texts[i]))
+
+            metrics["avg_seqlen"].append(len(ll))
+
+        for m in metrics:
+            metrics[m] = sum(metrics[m]) / len(metrics[m])
+        metrics.update(dist_mean_dict(metrics))
+        logger.info(f"Validation on {src} done. Metrics: {metrics}")
+
+        name = os.path.basename(src)
+        if name in all_val_metrics:
+            logger.warning(
+                f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
+            )
+            name = f"{name}_1"
+        all_val_metrics[name] = metrics
+
+    generator.max_gen_len = max_gen_len
+
+    return all_val_metrics
+
+
+def launch_eval(cfg: EvalArgs):
+    if not torch.distributed.is_initialized():
+        setup_torch_distributed(DistributedArgs())
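+    # Use ckpt_dir directly if it already contains a consolidated checkpoint
+    # (a params.json plus a *.pth file); otherwise consolidate the sharded
+    # checkpoint into CONSOLIDATE_FOLDER (performed on rank 0 only).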
+    if (
+        Path(cfg.ckpt_dir).exists()
+        and (Path(cfg.ckpt_dir) / "params.json").exists()
+        and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
+    ):
+        consolidate_path = Path(cfg.ckpt_dir)
+    else:
+        consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
+        if not consolidate_path.exists() and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)
+
+    Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
+    dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)
+
+    consolidate_path = str(consolidate_path)
+    torch.distributed.barrier()
+    logger.info("Loading model")
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        consolidate_path,
+        model_cls=LMTransformer,
+        model_args_cls=LMTransformerArgs,
+    )
+    logger.info("Model loaded")
+    model.eval()
+    generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)
+
+    wrap = EvalHarnessLM(generator)
+    results = simple_evaluate(wrap, **asdict(cfg.harness))
+    val_results = None
+    if cfg.validation:
+        val_results = eval_on_val(generator, cfg.validation, train_cfg)
+    if get_global_rank() == 0:
+        with open(Path(cfg.dump_dir) / "results.json", "w") as f:
+            f.write(json.dumps(results))
+        logger.info(f"All evaluation results: {results['results']}")
+        if val_results is not None:
+            with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
+                f.write(json.dumps(val_results))
+            logger.info(f"All validation results: {val_results}")
+    if cfg.metric_log_dir and get_global_rank() == 0:
+        metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"
+
+        logger.info(f"Writing metric logs to {metric_log_path}")
+        timestamp = {
+            "created_at": datetime.utcnow().isoformat(),
+        }
+        if cfg.global_step is not None:
+            timestamp["global_step"] = cfg.global_step
+        print(
+            json.dumps(timestamp | results["results"]),
+            file=open(metric_log_path, mode="a"),
+            flush=True,
+        )
+
+        val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
+        if val_results is not None:
+            print(
+                json.dumps(timestamp | val_results),
+                file=open(val_log_path, mode="a"),
+                flush=True,
+            )
+
+    del generator
+
+
+def main():
+    """
+    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
+    This accepts arguments as a dot list
+    So if the dataclass looks like
+
+    @dataclass
+    class DummyArgs:
+        name: str
+        model: LMTransformerArgs
+
+    @dataclass
+    class LMTransformerArgs:
+        dim: int
+
+    Then you can pass model.dim=32 to change values in LMTransformerArgs
+    or just name=tictac for top level attributes.
+
+    The behavior here is as follows:
+    1. We instantiate EvalArgs with its default values
+    2. We override those default values with the ones in the provided config file
+    3. We override the result with the additional arguments provided through command line
+
+    For example, if the config is the following
+
+    model:
+        dim: 128
+        n_layers: 4
+
+    and you call eval.py with eval.py model.dim=64
+
+    Then the final EvalArgs will have
+
+    model:
+        dim: 64
+        n_layers: 4
+
+    Plus all the default values in EvalArgs dataclass.
+    """
+    cli_args = OmegaConf.from_cli()
+    file_cfg = OmegaConf.load(cli_args.config)
+    # We remove 'config' attribute from config as the underlying DataClass does not have it
+    del cli_args.config
+
+    default_cfg = OmegaConf.structured(EvalArgs())
+    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
+    cfg = OmegaConf.to_object(cfg)
+    launch_eval(cfg)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/apps/main/generate.py b/apps/main/generate.py
new file mode 100644
index 0000000..a3a8627
--- /dev/null
+++ b/apps/main/generate.py
@@ -0,0 +1,463 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from lingua.args import dataclass_from_dict
+from lingua.tokenizers.abstract_tokenizer import Tokenizer
+from lingua.tokenizers.build_tokenizer import build_tokenizer
+from omegaconf import OmegaConf
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import create_block_mask
+from tqdm import tqdm
+
+from bytelatent.base_transformer import (
+    Attention,
+    causal_mask,
+    generate_doc_mask_mod,
+    lengths_to_local_ids,
+    lengths_to_start_ids,
+)
+from bytelatent.checkpoint import CONSOLIDATE_NAME
+from bytelatent.transformer import LMTransformer, LMTransformerArgs
+
+
+def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
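+    # Nucleus (top-p) sampling: keep the smallest set of highest-probability tokens
+    # whose cumulative mass reaches p and sample among them. For example (hypothetical
+    # values), probs = [[0.5, 0.3, 0.2]] with p = 0.7 keeps only the top two tokens.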
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+
+def sample_top_k(probs, k):
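+    # Top-k sampling: zero out every probability below the k-th largest, renormalize,
+    # and sample. For example (hypothetical values), probs = [[0.5, 0.3, 0.2]] with
+    # k = 2 only ever draws one of the first two tokens.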
+    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
+    min_value_top_k = topk_value[:, [-1]]
+    probs[probs < min_value_top_k] = 0.0
+    probs.div_(probs.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs, num_samples=1)
+    return next_token
+
+
+def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
+    shape = logits.shape
+    logits = logits.flatten(end_dim=-2)
+    if temperature > 0.0:
+        probs = torch.softmax(logits / temperature, dim=-1)
+
+        if top_p is not None:
+            next_token = sample_top_p(probs, top_p)
+        elif top_k is not None:
+            next_token = sample_top_k(probs, top_k)
+        else:
+            next_token = torch.multinomial(probs, num_samples=1)
+    else:
+        next_token = torch.argmax(logits, dim=-1)
+    return next_token.view(shape[:-1])
+
+
+def pack_prompts(prompts: List[int]):
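+    # Flattens a list of token-id lists into one packed 1D tensor plus per-prompt
+    # lengths, e.g. pack_prompts([[1, 2, 3], [4, 5]]) returns
+    # (tensor([1, 2, 3, 4, 5]), tensor([3, 2])).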
+    res = []
+    lengths = []
+    for i, p in enumerate(prompts):
+        p = torch.tensor(p, dtype=torch.long)
+        l = p.size(0)
+        res.append(p)
+        lengths.append(l)
+    lengths = torch.tensor(lengths, dtype=torch.long)
+    res = torch.cat(res)
+    return res, lengths
+
+
+def batch_prompts(prompts, max_elements, lengths=None):
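+    # Greedily groups consecutive prompts into batches whose total size does not
+    # exceed max_elements, e.g. batch_prompts([[1, 2], [3, 4, 5], [6]], 4) returns
+    # [[[1, 2]], [[3, 4, 5], [6]]].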
+    batches = []
+    current_batch = []
+    current_count = 0
+
+    for i in range(len(prompts)):
+        prt = prompts[i]
+        prompt_size = len(prt) if lengths is None else lengths[i]
+        if current_count + prompt_size <= max_elements:
+            current_batch.append(prt)
+            current_count += prompt_size
+        else:
+            if current_batch:  # Add the current batch to batches
+                batches.append(current_batch)
+            # Start a new batch with the current prompt
+            current_batch = [prt]
+            current_count = prompt_size
+
+    # Add the last batch if it contains any prompts
+    if current_batch:
+        batches.append(current_batch)
+
+    return batches
+
+
+class KVCache(nn.Module):
+    def __init__(self, bsz, seqlen, n_heads, head_dim, dtype, device):
+        super().__init__()
+        shape = (bsz, seqlen, n_heads, head_dim)
+        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype, device=device))
+        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype, device=device))
+        self.offset = 0
+
+    def reset(self):
+        self.k_cache.zero_()
+        self.v_cache.zero_()
+        self.offset = 0
+
+    def update(self, k_val, v_val, tok_idx):
+        # tok_idx: [S], k_val: [B, S, H, D]
+        self.k_cache.index_copy_(1, self.offset + tok_idx, k_val)
+        self.v_cache.index_copy_(1, self.offset + tok_idx, v_val)
+        return self.k_cache, self.v_cache
+
+
+@dataclass
+class PackedCausalTransformerGeneratorArgs:
+    temperature: float = 0.0
+    top_p: Optional[float] = None
+    top_k: Optional[float] = None
+    max_gen_len: int = 512  # Maximum number of tokens to generate
+    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
+    max_prompt_len: Optional[int] = None
+    until: List[str] = field(default_factory=list)
+    compile_prefilling: bool = False
+    reduce_generation_overhead: bool = False
+    show_progress: bool = False
+    dtype: Optional[str] = "bf16"
+    device: Optional[str] = "cuda"
+
+
+class PackedCausalTransformerGenerator:
+    def __init__(
+        self,
+        cfg: PackedCausalTransformerGeneratorArgs,
+        model: nn.Module,
+        tokenizer: Tokenizer,
+    ):
+        """
+        This class wraps a causal transformer model with its corresponding tokenizer
+        and provides an efficient way to pack prompts together and do generation on
+        the packed sequence.
+
+        For example, if we had the prompts "Hello, I am a " and "Initiating calibration ",
+        then this class will concatenate those sequences (pack them together) into
+        "Hello, I am a Initiating calibration"
+        And make the necessary attention masks such that a sequence only attends to itself
+        during prefilling and generation.
+
+        This class creates a fixed-size cache of size max_tokens, or of the sum of
+        prompt sizes plus the max number of generated tokens per sequence.
+        """
+        self.model = model
+        self.tokenizer = tokenizer
+        self.temperature = cfg.temperature
+        self.top_p = cfg.top_p
+        self.top_k = cfg.top_k
+
+        self.max_gen_len = cfg.max_gen_len
+        self.max_tokens = cfg.max_tokens
+        self.max_prompt_len = cfg.max_prompt_len
+        self.until = cfg.until
+        self.max_until_size = max([len(e) for e in self.until]) if self.until else 1
+        self.device = cfg.device
+
+        # Compile if necessary
+        self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
+        self.generate_next_token = torch.compile(
+            self.generate_next_token,
+            mode="reduce-overhead",
+            disable=not cfg.reduce_generation_overhead,
+        )
+
+        self.show_progress = cfg.show_progress
+        self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]
+
+        self.prefill_doc_id, self.prefill_tok_id = None, None
+        self.padded_doc_id, self.padded_tok_id = None, None
+        self.current_doc_id, self.current_tok_id = None, None
+        self.padded_doc_start = None
+        self.prefill_mask = None
+
+    def clear_cache(self, offset):
+        for module in self.model.modules():
+            if isinstance(module, Attention):
+                if not hasattr(module, "kv_cache"):
+                    module.kv_cache = KVCache(
+                        1,
+                        self.max_tokens,
+                        module.n_kv_heads,
+                        module.head_dim,
+                        self.dtype,
+                        self.device,
+                    )
+                module.kv_cache.offset = offset
+
+    @torch.compiler.disable
+    def setup_prefilling(self, lengths: torch.Tensor):
+        # The KV cache is a fixed size tensor of size max_tokens that we need
+        # to update in order to do correct autoregressive generation.
+
+        # Here we will generate token by token but on multiple sequences
+        # at once. To do so, we need to have an attention mask that makes
+        # each sequence independent.
+
+        # Each sequence will write to its allocated space in the KV Cache.
+        # We allocate len(seq) + max_gen_len to each sequence in the cache.
+
+        # We will generate max_gen_len for each document
+        padded_lengths = lengths + self.max_gen_len
+        max_tokens = self.max_tokens or padded_lengths.sum().item()
+        # The last document might have more padding to fill up to max_tokens
+        padded_lengths[-1] += max_tokens - padded_lengths.sum()
+
+        # This is the start index in the cache for each document
+        self.padded_doc_start = lengths_to_start_ids(padded_lengths)
+        # For example with ab--123--cdef--
+        # this would be 0, 4, 9 if max_gen_len is 2
+
+        # We repeat interleave to align with tokens for prefilling
+        # Ex: ab--123--cdef--
+        #     000044444999999
+        prefill_offset = torch.repeat_interleave(self.padded_doc_start, lengths)
+        # This offset will make sure the tokens are written to the
+        # correct positions in the cache during prefilling
+
+        # We either init the cache or clear it by resetting the offset to prefill_offset
+        self.clear_cache(prefill_offset)
+
+        # The prefilling mask looks like the following for
+        # the two packed sequences ab and 123 : ab123
+        # Where spaces are empty cache positions
+        #                 keys
+        #                ab---123---
+        #   queries    a 10000000000
+        #              b 11000000000
+        #              1 00000100000
+        #              2 00000110000
+        #              3 00000111000
+        # We make sure to skip the empty cache positions
+        # and only attend to positions within the same sequence
+        doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths, padded_lengths)
+        self.prefill_mask = create_block_mask(
+            doc_mask_mod, 1, None, lengths.sum(), max_tokens
+        )
+
+        # This creates the prefilling token ids which look like
+        # the following for the packed sequence abcdefg1234
+        # abcdefg1234
+        # 01234560123
+        # The token id gives us the position within each sequence
+        # This is used to compute ROPE and to update the cache
+        # At each forward pass the current tokens are written to
+        # offset + tok_id
+        self.prefill_doc_id, self.prefill_tok_id = lengths_to_local_ids(lengths)
+
+        # This creates the padded token and document ids
+        # which look like the following for the packed sequence ab123
+        #               ab---123---               ab---123---
+        # padded_doc_id 00000111111 padded_tok_id 01234012345
+        # This will later be useful for the attention mask at generation
+        self.padded_doc_id, self.padded_tok_id = lengths_to_local_ids(padded_lengths)
+
+    @torch.compiler.disable
+    def setup_generation(self, lengths):
+        # KV Cache offset is set to the start of the padded documents
+        for module in self.model.modules():
+            if isinstance(module, Attention):
+                module.kv_cache.offset = self.padded_doc_start
+        # The token ids during generations correspond to the lengths of each doc
+        # current_tok_id will be incremented during generation
+        self.current_tok_id = lengths.clone()
+        # Since we're generating one token per document
+        # the document id is just an arange
+        self.current_doc_id = torch.arange(lengths.size(0), device=lengths.device)
+
+    # From here on some methods for generation
+    def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
+        # Prefilling is done by taking multiple packed sequences and
+        # doing block diagonal attention on them so they remain independent
+        self.setup_prefilling(lengths=lengths)
+        prefill_out = self.model.forward(
+            tokens,
+            tok_idx=self.prefill_tok_id,
+            mask=self.prefill_mask,
+            attn_impl="flex_attention",
+        )
+        self.setup_generation(lengths=lengths)
+        return prefill_out
+
+    def generate_next_token(self, current_token):
+        # Since we're doing generation with multiple sequences at once
+        # we need to ignore tokens and cache entries from other sequences
+        # or in the future.
+        # Example mask :
+        #                  keys
+        #                abc--1234--
+        #   queries    c 11100000000
+        #              4 00000111100
+
+        # mask shape : (n_seqs, cache_size)
+        doc_mask = self.current_doc_id.unsqueeze(1) == self.padded_doc_id.unsqueeze(0)
+        caus_mask = self.current_tok_id.unsqueeze(1) >= self.padded_tok_id.unsqueeze(0)
+        mask = doc_mask & caus_mask
+        out = self.model.forward(
+            current_token,
+            tok_idx=self.current_tok_id,  # n_seqs
+            mask=mask,
+            attn_impl="sdpa",
+        )
+        self.current_tok_id += 1
+        return out
+
+    @torch.inference_mode()
+    def generate(self, prompts):
+        # Tokenize
+        prompts = [
+            self.tokenizer.encode(p, add_bos=True, add_eos=False) for p in prompts
+        ]
+        # Truncate
+        max_seqlen = (
+            self.max_tokens
+            if not hasattr(self.model, "max_seqlen")
+            else self.model.max_seqlen
+        )
+        max_prompt_len = self.max_prompt_len or min(
+            max_seqlen - self.max_gen_len, self.max_tokens - self.max_gen_len
+        )
+        prompts = [p[-max_prompt_len:] for p in prompts]
+        # Account for the generation in lengths
+        padded_lengths = [len(p) + self.max_gen_len for p in prompts]
+        generation = []
+        loglikelihood = []
+        greedy = []
+        it = batch_prompts(prompts, self.max_tokens, lengths=padded_lengths)
+        if self.show_progress:
+            it = tqdm(it)
+        for batch in it:
+            n_seqs = len(batch)
+            generated_tokens = [[] for _ in range(n_seqs)]
+            is_done = [False for _ in range(n_seqs)]
+            packed_batch, lengths = pack_prompts(batch)
+            packed_batch, lengths = packed_batch.cuda(), lengths.cuda()
+            n_seqs = lengths.size(0)
+
+            # Prefilling cache
+            prompt_logits = self.prefill(packed_batch.unsqueeze(0), lengths)
+            # Selecting last token in each prompt
+            all_tokens = sample_tokens(
+                prompt_logits, self.temperature, self.top_p, self.top_k
+            )
+            start_token = all_tokens[:, lengths.cumsum(0) - 1]
+
+            for seq_id, tok in enumerate(start_token.squeeze(0).tolist()):
+                generated_tokens[seq_id].append(tok)
+
+            current_token = start_token
+            for i in range(1, self.max_gen_len):
+
+                next_logits = self.generate_next_token(current_token)
+                next_token = sample_tokens(
+                    next_logits.clone(), self.temperature, self.top_p, self.top_k
+                )
+
+                for seq_id, tok in enumerate(next_token.squeeze(0).tolist()):
+                    if not is_done[seq_id]:
+                        generated_tokens[seq_id].append(tok)
+                        current_end_str = self.tokenizer.decode(
+                            generated_tokens[seq_id][-self.max_until_size :]
+                        )
+                        contains_end_string = any(
+                            [e in current_end_str for e in self.until]
+                        )
+                        is_done[seq_id] = (
+                            contains_end_string or tok == self.tokenizer.eos_id
+                        )
+                if all(is_done):
+                    break
+
+                current_token = next_token
+
+            generation.extend([self.tokenizer.decode(g) for g in generated_tokens])
+
+            for p, logit in zip(
+                batch, prompt_logits.squeeze(0).split(lengths.tolist())
+            ):
+                x = logit[:-1]
+                y = torch.tensor(p[1:], device=x.device)
+                loglikelihood.append(-F.cross_entropy(x, y, reduction="none").cpu())
+                greedy.append((x.argmax(dim=-1) == y).cpu())
+
+        return generation, loglikelihood, greedy
+
+
+def load_consolidated_model_and_tokenizer(
+    consolidated_path,
+    model_cls=LMTransformer,
+    model_args_cls=LMTransformerArgs,
+):
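+    # Reads params.json to rebuild the model args, builds the matching tokenizer,
+    # loads the consolidated state dict, and casts parameters to the training dtype.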
+    ckpt_path = Path(consolidated_path)
+    config = ckpt_path / "params.json"
+    config = OmegaConf.load(config)
+
+    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
+        config.distributed.model_dtype
+    ]
+    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
+    tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
+    model = model_cls(model_args)
+    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
+    model.load_state_dict(st_dict["model"])
+    model = model.cuda().eval()
+    for param in model.parameters():
+        param.data = param.data.to(dtype=param_dtype)
+    return model, tokenizer, config
+
+
+def main():
+    # Load CLI arguments (overrides) and combine with a YAML config
+    cfg = OmegaConf.from_cli()
+    gen_cfg = dataclass_from_dict(
+        PackedCausalTransformerGeneratorArgs, cfg, strict=False
+    )
+    print(cfg)
+
+    model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)
+
+    generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
+
+    # Allow multiple prompts
+    prompts = []
+    while True:
+        prompt = input("Enter a prompt (or press enter to finish): ")
+        if not prompt:
+            break
+        prompts.append(prompt)
+
+    # Start generation
+    start_time = time.time()
+    generation, loglikelihood, greedy = generator.generate(prompts)
+    end_time = time.time()
+
+    # Calculate tokens per second
+    total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
+    tokens_per_second = total_tokens / (end_time - start_time)
+
+    # Display the results
+    for i, gen in enumerate(generation):
+        print(f"\nPrompt {i+1}: {prompts[i]}")
+        print(f"Generated Text: {gen}")
+
+    print(f"\nTokens per second: {tokens_per_second:.2f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/apps/main/lingua_train.py b/apps/main/lingua_train.py
new file mode 100644
index 0000000..bdb47da
--- /dev/null
+++ b/apps/main/lingua_train.py
@@ -0,0 +1,654 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import gc
+import logging
+import os
+import sys
+from contextlib import ExitStack
+from copy import deepcopy
+from dataclasses import asdict, dataclass, field
+from pathlib import Path
+from timeit import default_timer as timer
+from typing import Any, Dict, Optional
+
+import torch
+import torch.distributed
+import wandb
+import xformers.profiler
+from lingua.args import dataclass_from_dict, dump_config, flatten_dict
+from lingua.data import (
+    DataArgs,
+    PackTokensState,
+    build_dataloader_from_args,
+    init_dataloader_state_from_args,
+)
+from lingua.tokenizers.build_tokenizer import TokenizerArgs
+from omegaconf import OmegaConf
+from pydantic import BaseModel
+from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.optim import lr_scheduler
+
+from bytelatent.checkpoint import (
+    CheckpointArgs,
+    CheckpointManager,
+    load_from_checkpoint,
+)
+from bytelatent.distributed import (
+    DistributedArgs,
+    EnvironmentArgs,
+    check_model_value_range,
+    clean_env,
+    dist_mean_dict,
+    get_device_mesh,
+    get_is_master,
+    get_world_size,
+    init_signal_handler,
+    parallelize_model,
+    requeue_slurm_job,
+    setup_env,
+    setup_torch_distributed,
+)
+from bytelatent.logger import init_logger
+from bytelatent.metrics import (
+    GPUMemoryMonitor,
+    LoggingArgs,
+    MetricLogger,
+    get_num_params,
+)
+from bytelatent.optim import OptimArgs, build_optimizer
+from bytelatent.probe import AutoProbeD
+from bytelatent.profiling import ProfilerArgs, maybe_run_profiler
+from bytelatent.stool import StoolArgs, launch_job
+from bytelatent.transformer import (
+    LMTransformer,
+    LMTransformerArgs,
+    build_fsdp_grouping_plan,
+    get_no_recompute_ops,
+    get_num_flop_per_token,
+    tp_parallelize,
+)
+
+logger = logging.getLogger()
+
+
+class TrainArgs(BaseModel):
+    name: str = "lingua"
+    dump_dir: str = ""
+
+    seed: int = 42
+
+    # Number of gradient accumulation steps
+    # Total batch size is batch_size*grad_acc_steps
+    grad_acc_steps: int = 1
+
+    gc_collect_freq: int = 1000
+    probe_freq: int | None = None
+
+    # Nb optimizer steps to take
+    steps: int = 1000
+
+    data: DataArgs
+    optim: OptimArgs
+    model: LMTransformerArgs
+    distributed: DistributedArgs
+    env: EnvironmentArgs
+
+    checkpoint: CheckpointArgs
+    profiling: ProfilerArgs
+    logging: LoggingArgs
+
+    # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
+    async_eval_gpus: int | None = None
+    eval: Any | None = None
+
+
+@dataclass
+class TrainState(Stateful):
+    step: int  # Nb of steps taken by the optimizer
+    acc_step: int  # Nb of accumulation steps done since last optimizer step
+    scheduler: lr_scheduler.LambdaLR
+    data_loader_state: PackTokensState
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "step": self.step,
+            "acc_step": self.acc_step,
+            "data_loader_state": self.data_loader_state,
+            "scheduler": self.scheduler.state_dict(),
+        }
+
+    def load_state_dict(self, state_dict):
+        self.step = state_dict["step"]
+        self.acc_step = state_dict["acc_step"]
+        self.data_loader_state = PackTokensState(**state_dict["data_loader_state"])
+        self.scheduler.load_state_dict(state_dict["scheduler"])
+
+
+def validate_train_args(args: TrainArgs, output_size: int):
+    if args.model.vocab_size < 0:
+        logger.info(f"Setting model output size to {args.model.vocab_size}")
+        args.model.vocab_size = output_size
+    assert (
+        args.model.vocab_size == output_size
+    ), "Vocab size should be the same as output size"
+
+    assert args.dump_dir, "Dump dir not set"
+
+    if args.checkpoint.path is None:
+        logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
+        args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
+
+    for source in args.data.sources:
+        data_path = os.path.join(args.data.root_dir, source)
+        assert os.path.exists(data_path), f"{data_path} doesn't exist"
+
+    if (
+        args.distributed.dp_replicate
+        * args.distributed.dp_shard
+        * args.distributed.tp_size
+        != get_world_size()
+    ):
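+        # Re-derive dp_replicate from the world size: e.g. (hypothetical sizes) with
+        # 16 GPUs, dp_shard=8 and tp_size=2, this yields dp_replicate = 16 // 8 // 2 = 1.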
+        assert get_world_size() % args.distributed.dp_shard == 0
+        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
+
+        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
+        args.distributed.dp_replicate = (
+            args.distributed.dp_replicate // args.distributed.tp_size
+        )
+
+        logger.warning(
+            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
+        )
+        assert (
+            args.distributed.dp_replicate
+            * args.distributed.dp_shard
+            * args.distributed.tp_size
+            == get_world_size()
+        )
+
+        if args.distributed.fsdp_type == "no_shard":
+            assert (
+                args.distributed.dp_shard == 1
+                and args.distributed.dp_replicate == get_world_size()
+            )
+
+    args.model.max_seqlen = args.data.seq_len
+
+    if args.distributed.tp_size == 1:
+        logger.warning(
+            "Tensor parallelism has not been tested for a while, use at your own risk"
+        )
+
+    assert (
+        args.probe_freq != args.profiling.mem_steps
+    ), "Don't profile during probe step"
+    assert (
+        args.probe_freq != args.profiling.profile_steps
+    ), "Don't profile during probe step"
+    if args.logging.wandb is not None:
+        args.logging.wandb.name = args.name
+
+    if args.probe_freq is not None:
+        assert (
+            args.distributed.tp_size == 1
+        ), "Probing not supported with tensor parallelism"
+        assert (
+            args.distributed.selective_activation_checkpointing is False
+        ), "Probing not supported with selective activation checkpointing"
+
+
+preemption_flag = dict(flag=False)
+
+
+def set_preemption_flag(signum, frame):
+    logger.warning("Signal handler called with signal " + str(signum))
+    logger.warning("Preemption ! checkpointing asap and exiting.")
+    preemption_flag["flag"] = True
+
+
+def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
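+    """
+    Return True when train_state.step is a multiple of freq. If acc_step is
+    given, additionally require train_state.acc_step == acc_step; if acc_freq
+    is given instead, require train_state.acc_step to be a multiple of acc_freq.
+    Used to gate periodic work (GC, logging, checkpointing, eval) so it runs
+    at most once per gradient-accumulation cycle.
+    """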
+    test = train_state.step % freq == 0
+    if acc_step is not None:
+        test = test and (train_state.acc_step == acc_step)
+    elif acc_freq is not None:
+        test = test and ((train_state.acc_step % acc_freq) == 0)
+    return test
+
+
+def train(args: TrainArgs):
+    with ExitStack() as context_stack:
+        tokenizer_args = TokenizerArgs(
+            name=args.data.name,
+            init_kwargs=args.data.tokenizer.init_kwargs,
+        )
+        tokenizer = tokenizer_args.build()
+        validate_train_args(
+            args,
+            tokenizer.n_words,
+        )
+        if get_is_master():
+            os.makedirs(args.dump_dir, exist_ok=True)
+            dump_config(args, Path(args.dump_dir) / "config.yaml")
+        init_logger(Path(args.dump_dir) / "train.log")
+        init_signal_handler(set_preemption_flag)  # For handling preemption signals.
+        setup_env(args.env)
+        setup_torch_distributed(args.distributed)
+        world_mesh = get_device_mesh(args.distributed)
+        logger.info(f"Starting job: {args.name}")
+
+        # build dataloader
+        # need dp world size and rank
+        dp_mesh = world_mesh["dp_replicate"]
+        dp_degree = dp_mesh.size()
+        dp_rank = dp_mesh.get_local_rank()
+        if args.distributed.dp_shard > 1:
+            dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
+            dp_degree *= world_mesh["dp_shard"].size()
+
+        logger.info(f"Running on dp rank : {dp_rank}")
+        logger.info(f"Running on dp size : {dp_degree}")
+
+        torch.manual_seed(args.seed)
+        logger.info("Building model")
+
+        # Initializing the model on the meta device lets us build models much bigger than a single GPU's memory
+        with torch.device("meta"):
+            model = LMTransformer(args.model)
+        logger.info("Model is built !")
+
+        model_param_count = get_num_params(model)
+
+        model = parallelize_model(
+            model,
+            world_mesh,
+            args.model,
+            args.distributed,
+            fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
+            tp_parallelize=tp_parallelize,
+            no_recompute_ops=get_no_recompute_ops(),
+        )
+
+        # Once we shard the model across GPUs we can actually initialize it.
+        # First we create empty tensors of the correct shapes
+        model = model.to_empty(device="cuda")
+        # Then we init the model. Please make sure this function initializes *ALL* parameters
+        # and buffers, otherwise you will have random values in the uninitialized tensors,
+        # which will fail silently (e.g. give NaN gradients)
+
+        if args.checkpoint.init_ckpt_path:
+            logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
+            load_from_checkpoint(
+                args.checkpoint.init_ckpt_path, model, model_key="model"
+            )  # Pass model_key="" if it's directly the model checkpoint
+            model.rope_embeddings.reset_parameters()  # Re-init RoPE: since freqs_cis is a non-persistent buffer it might not be loaded
+        else:
+            with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
+                torch.manual_seed(args.model.seed)
+                model.init_weights()
+        check_model_value_range(model, range=10.0, std=1.0)
+
+        # log model size
+
+        logger.info(f"Model size: {model_param_count:,} total parameters")
+
+        gpu_memory_monitor = GPUMemoryMonitor("cuda")
+        logger.info(
+            f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
+            f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
+        )
+        logger.info(f"GPU memory usage: {gpu_memory_monitor}")
+
+        # build the optimizer after applying parallelisms to the model
+        optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
+        data_loader_state = init_dataloader_state_from_args(
+            args.data, dp_rank, dp_degree
+        )
+
+        train_state = TrainState(
+            step=0,
+            acc_step=0,
+            data_loader_state=data_loader_state,
+            scheduler=scheduler,
+        )
+
+        checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
+        # Either load from the latest checkpoint or start from scratch
+        checkpoint.load(model, optimizer, train_state, world_mesh)
+        if args.probe_freq is not None:
+            if get_is_master():
+                os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
+            torch.distributed.barrier()
+            probe = AutoProbeD(
+                model,
+                (
+                    Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
+                    if (dp_rank % 128 == 0)
+                    else None
+                ),
+            )
+            probe_mod = model._orig_mod if args.distributed.compile else model
+
+        gc.disable()
+
+        # train loop
+        model.train()
+        metric_logger = context_stack.enter_context(
+            MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
+        )
+        data_loader = context_stack.enter_context(
+            build_dataloader_from_args(
+                args.data,
+                state=train_state.data_loader_state,
+            )
+        )
+        torch_profiler = context_stack.enter_context(
+            maybe_run_profiler(args.dump_dir, model, args.profiling)
+        )
+
+        nwords_since_last_log = 0
+        time_last_log = timer()
+        gc.collect()
+        while train_state.step < args.steps:
+            # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
+            train_state.acc_step += 1
+            train_state.acc_step = train_state.acc_step % args.grad_acc_steps
+
+            # get batch
+            curr_lr = float(optimizer.param_groups[0]["lr"])
+            data_load_start = timer()
+            batch, train_state.data_loader_state = next(data_loader)
+            batch = torch.tensor(
+                batch,
+                dtype=torch.long,
+            )
+
+            if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
+                logger.info("garbage collection")
+                # We run garbage collection manually; otherwise different processes
+                # trigger the GC at different times and slow down the whole pipeline
+                gc.collect()
+
+            input_ids = batch[:, :, 0].cuda()
+            labels = batch[:, :, 1].cuda()
+            data_load_time = round(timer() - data_load_start, 4)
+            nwords_since_last_log += input_ids.numel()
+
+            bsz, seqlen = labels.shape
+
+            # forward
+            start_timer = torch.cuda.Event(enable_timing=True)
+            end_timer = torch.cuda.Event(enable_timing=True)
+            start_timer.record()
+
+            # This is an automatic probe that will compute statistics
+            # of all linears' inputs, weights and outputs
+            # along with attention logits and entropy
+            # both in forward and backward pass
+            if (args.probe_freq is not None) and every_n_steps(
+                train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
+            ):
+                # Here we do a fake forward and backward pass on a smaller
+                # batch size to avoid OOM
+                # This assumes the model has no stateful layers (batch norm..)
+                assert (
+                    next(probe_mod.parameters()).grad is None
+                ), "Can't probe model if grads are not reset"
+
+                with probe:
+                    probe.metadata = {
+                        "it": train_state.step,
+                        "global_step": train_state.step,
+                        "loop": "lingua",
+                    }
+                    # The non-compiled model uses roughly 2x memory in our experiments,
+                    # so we halve either bsz or seqlen
+                    probe_bsz = max(1, bsz // 2)
+                    probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
+                    probe_loss = probe_mod(
+                        input_ids[:probe_bsz, :probe_seq],
+                        labels[:probe_bsz, :probe_seq],
+                    )
+                    probe_loss.backward()
+                    # We zero grads to cancel this fake step
+                    optimizer.zero_grad()
+
+                assert (
+                    next(probe_mod.parameters()).grad is None
+                ), "Probe model shouldn't have grads at this point"
+            loss = model(input_ids, labels)
+
+            # We scale loss with grad_acc_steps so the gradient is the same
+            # regardless of grad_acc_steps
+            loss = loss / args.grad_acc_steps
+            # backward on scaled loss to create scaled gradients
+            loss.backward()
+            # For logging we undo that scaling
+            loss = loss.detach() * args.grad_acc_steps
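+            # Net effect over one accumulation cycle: the accumulated gradient equals
+            # grad(mean of the grad_acc_steps micro-batch losses), i.e. the same as a
+            # single step on the combined batch, while the logged loss stays unscaled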
+
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                model.parameters(), max_norm=args.optim.clip, foreach=True
+            )
+
+            grad_norm = (
+                grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
+            ).item()
+
+            # optimizer step
+            if train_state.acc_step == 0:
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+                train_state.step += 1
+
+            # training iteration complete
+            end_timer.record()
+
+            torch.cuda.synchronize()
+
+            curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)
+
+            # if profiler is active
+            if torch_profiler:
+                xformers.profiler.step()
+
+            # log metrics
+            if every_n_steps(
+                train_state,
+                args.logging.freq,
+                acc_step=None if args.logging.acc_freq else 0,
+                acc_freq=args.logging.acc_freq,
+            ):
+                time_delta = timer() - time_last_log
+                wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)
+
+                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
+
+                total_acc_steps = (
+                    args.grad_acc_steps * train_state.step + train_state.acc_step
+                )
+                tokens_per_gpu = (
+                    total_acc_steps * args.data.batch_size * args.data.seq_len
+                )
+                total_tokens = dp_degree * tokens_per_gpu
+                # This is an estimate and the correct value may change
+                # if you change the architecture.
+                # Use xformers' profile trace analysis to get an actual measurement.
+                FLOPS = (
+                    get_num_flop_per_token(
+                        model_param_count - args.model.vocab_size * args.model.dim,
+                        args.model.n_layers,
+                        args.model.dim,
+                        args.data.seq_len,
+                    )
+                    * wps
+                )
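+                # FLOPS here is an estimated model throughput in FLOPs/s:
+                # FLOPs per token (token-embedding parameters excluded from the
+                # count) times the measured words-per-second.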
+                metrics = flatten_dict(
+                    {
+                        "global_step": train_state.step,
+                        "acc_step": train_state.acc_step,
+                        "speed": {
+                            "wps": wps,
+                            "FLOPS": FLOPS,
+                            "curr_iter_time": curr_iter_time,
+                            "data_load_time": data_load_time,
+                        },
+                        "optim": {
+                            "grad_norm": grad_norm,
+                            "lr": curr_lr,
+                            "total_tokens": total_tokens,
+                        },
+                        "memory": gpu_mem_stats._asdict(),
+                    },
+                    sep="/",
+                )
+
+                to_sync = {}
+                to_sync["loss/out"] = loss.item()
+                metrics.update(dist_mean_dict(to_sync))
+
+                if get_is_master():
+                    metric_logger.log(metrics)
+
+                gpu_memory_monitor.reset_peak_stats()
+                nwords_since_last_log = 0
+                time_last_log = timer()
+                logger.info(
+                    f"step: {train_state.step}"
+                    f"  acc: {train_state.acc_step}"
+                    f"  loss: {round(loss.item(),4):>7}"
+                    f"  grad: {grad_norm:.2e}"
+                    f"  flops: {FLOPS:.2e}"
+                    f"  wps: {wps:.2e}"
+                    f"  iter: {curr_iter_time:>7}"
+                    f"  data: {data_load_time:>5}"
+                    f"  lr: {curr_lr:.2e}"
+                    f"  mem: {gpu_mem_stats.max_active_pct:.0f}%"
+                    f"  pow: {gpu_mem_stats.power_draw/1000} W"
+                )
+
+            saved = False
+            if every_n_steps(
+                train_state, args.checkpoint.dump.every, acc_step=0
+            ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
+                saved = checkpoint.save(
+                    model,
+                    optimizer,
+                    train_state,
+                    args,
+                    device_mesh=world_mesh,
+                )
+
+            if args.eval is not None and every_n_steps(
+                train_state, args.checkpoint.eval.every, acc_step=0
+            ):
+                from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
+
+                eval_args = dataclass_from_dict(EvalArgs, args.eval)
+
+                eval_args.global_step = train_state.step
+                eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
+                eval_args.dump_dir = str(
+                    os.path.join(
+                        args.dump_dir,
+                        "evals",
+                        EVAL_FOLDER_NAME.format(train_state.step),
+                    )
+                )
+                eval_args.metric_log_dir = args.dump_dir
+                if args.async_eval_gpus is None:
+                    launch_eval(eval_args)
+                elif get_is_master():
+                    if wandb.run is not None and args.logging.wandb is not None:
+                        eval_args.wandb = deepcopy(args.logging.wandb)
+                    assert args.async_eval_gpus > 0
+                    logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
+                    with clean_env():
+                        launch_job(
+                            StoolArgs(
+                                asdict(eval_args),
+                                script="apps.main.eval",
+                                copy_code=False,
+                                nodes=args.async_eval_gpus // 8,
+                                qos="lowest",
+                            )
+                        )
+
+            if preemption_flag["flag"]:
+                if not saved:
+                    checkpoint.save(
+                        model,
+                        optimizer,
+                        train_state,
+                        args,
+                        device_mesh=world_mesh,
+                    )
+                requeue_slurm_job()
+                sys.exit(0)
+
+    if not saved:
+        checkpoint.save(
+            model,
+            optimizer,
+            train_state,
+            args,
+            device_mesh=world_mesh,
+        )
+    gc.collect()
+
+
+def main():
+    """
+    The command line interface here uses OmegaConf
+    (https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments),
+    which accepts arguments as a dot list.
+    So if the dataclass looks like
+
+    @dataclass
+    class DummyArgs:
+        name: str
+        model: LMTransformerArgs
+
+    @dataclass
+    class LMTransformerArgs:
+        dim: int
+
+    Then you can pass model.dim=32 to change values in LMTransformerArgs,
+    or just name=tictac for top-level attributes.
+
+    The behavior here is as follows:
+    1. We instantiate TrainArgs with its default values
+    2. We override those default values with the ones in the provided config file
+    3. We override the result with the additional arguments provided through command line
+
+    For example, if the config is the following
+
+    model:
+        dim: 128
+        n_layers: 4
+
+    and you call train.py with model.dim=64
+
+    Then the final TrainArgs will have
+
+    model:
+        dim: 64
+        n_layers: 4
+
+    Plus all the default values in TrainArgs dataclass.
+    """
+    cli_args = OmegaConf.from_cli()
+    file_cfg = OmegaConf.load(cli_args.config)
+    # We remove the 'config' attribute from the CLI args as the underlying TrainArgs does not have it
+    del cli_args.config
+
+    default_cfg = OmegaConf.structured(TrainArgs())
+    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
+    cfg = OmegaConf.to_object(cfg)
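+    # At this point cfg follows the precedence described in the docstring:
+    # TrainArgs defaults < config-file values < CLI dot-list overrides, e.g.
+    #   python train.py config=<path/to/config.yaml> model.dim=64
+    # (hypothetical invocation; the config= argument is required)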
+
+    train(cfg)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/blt-figure.jpg b/blt-figure.jpg
new file mode 100644
index 0000000..e27edc6
Binary files /dev/null and b/blt-figure.jpg differ
diff --git a/blt-figure.pdf b/blt-figure.pdf
new file mode 100644
index 0000000..045f15d
Binary files /dev/null and b/blt-figure.pdf differ
diff --git a/bytelatent/.DS_Store b/bytelatent/.DS_Store
new file mode 100644
index 0000000..5a50b07
Binary files /dev/null and b/bytelatent/.DS_Store differ
diff --git a/bytelatent/__init__.py b/bytelatent/__init__.py
new file mode 100644
index 0000000..5bc057f
--- /dev/null
+++ b/bytelatent/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+class ByteLatentError(Exception):
+    pass
diff --git a/bytelatent/args.py b/bytelatent/args.py
new file mode 100644
index 0000000..4fba100
--- /dev/null
+++ b/bytelatent/args.py
@@ -0,0 +1,199 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
+import os
+from typing import Any
+
+import numpy as np
+import yaml
+from pydantic import BaseModel, ConfigDict
+
+from bytelatent.checkpoint import CheckpointArgs
+from bytelatent.data.data_types import Batch
+from bytelatent.data.iterators.abstract_iterator import StatefulIterator
+from bytelatent.data.iterators.arrow_iterator import (
+    ArrowFileIterator,
+    find_and_sanitize_chunks,
+)
+from bytelatent.data.iterators.looping_iterator import LoopingIterator
+from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
+from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
+from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
+from bytelatent.data.iterators.sampling_iterator import SamplingIterator
+from bytelatent.data.iterators.sequence_iterator import (
+    SequenceIterator,
+    SequencePackingArgs,
+)
+from bytelatent.data.patcher import PatcherArgs
+from bytelatent.distributed import DistributedArgs, EnvironmentArgs
+from bytelatent.metrics import LoggingArgs
+from bytelatent.model.blt import ByteLatentTransformerArgs
+from bytelatent.optim import OptimArgs
+from bytelatent.profiling import ProfilerArgs
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+
+logger = logging.getLogger()
+
+
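+# Seeding a fresh generator with the (seed, rank, world_size) tuple gives each
+# rank its own reproducible RNG stream without any cross-rank communication.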
+def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
+    return np.random.default_rng((seed, rank, world_size)).bit_generator.state
+
+
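+# Each rank ends up with exactly one ArrowFileIterator: dataset chunks are
+# enumerated in order and each chunk is further sharded across
+# world_size // len(dataset_chunks) workers, so rank r maps to a single
+# (chunk, worker_id) pair.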
+def distribute_data_to_rank(
+    *,
+    dataset_path: str,
+    preprocess_dir: str,
+    entropy_model_name: str | None,
+    arrow_batch_size: int,
+    rank: int,
+    world_size: int,
+) -> ArrowFileIterator:
+    dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size)
+    n_workers_per_chunk = world_size // len(dataset_chunks)
+    rank_to_arrow_iterator_params = []
+    for chunk_path in dataset_chunks:
+        for worker_id in range(n_workers_per_chunk):
+            rank_to_arrow_iterator_params.append(
+                ArrowFileIterator(
+                    file_path=chunk_path,
+                    worker_id=worker_id,
+                    num_workers=n_workers_per_chunk,
+                    preprocess_dir=preprocess_dir,
+                    dataset_files=None,
+                    entropy_model_name=entropy_model_name,
+                    arrow_batch_size=arrow_batch_size,
+                )
+            )
+    return rank_to_arrow_iterator_params[rank]
+
+
+class DataloaderArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    root_dir: str | None = None
+    sources: dict[str, float] = {}
+    batch_size: int = 2
+    seq_len: int = 2048
+    seed: int = 42
+    add_bos: bool = True
+    add_eos: bool = True
+    load_async: bool = True
+    prefetch_size: int = 64
+    preprocess_dir: str | None = None
+    dataset_files: list[str] | None = None
+    entropy_model_name: str | None = "transformer_100m"
+    arrow_batch_size: int = 100
+    buffer_size: int = 64
+
+    pad_to_max_length: bool = True
+    max_encoder_seq_length: int = 12288
+    enable_byte_ngrams: bool = False
+
+    tokenizer_args: TokenizerArgs = TokenizerArgs()
+    patcher_args: PatcherArgs = PatcherArgs()
+
+    def _create_sequence_iterators(
+        self, rank: int, world_size: int
+    ) -> dict[str, SequenceIterator]:
+        sequence_packing_args = SequencePackingArgs(
+            output_seq_len=self.seq_len,
+            buffer_size=self.buffer_size,
+        )
+        source_to_sequence_iterator: dict[str, SequenceIterator] = {}
+        for dataset_path in self.sources:
+            shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
+            arrow_iterator = distribute_data_to_rank(
+                dataset_path=os.path.join(self.root_dir, dataset_path),
+                preprocess_dir=self.preprocess_dir,
+                entropy_model_name=self.entropy_model_name,
+                arrow_batch_size=self.arrow_batch_size,
+                rank=rank,
+                world_size=world_size,
+            )
+            looping_iterator = LoopingIterator(arrow_iterator)
+            preprocess_iterator = PreprocessIterator(
+                looping_iterator,
+                patcher_args=self.patcher_args,
+                tokenizer_args=self.tokenizer_args,
+            )
+            sequence_iterator = SequenceIterator(
+                preprocess_iterator,
+                sequence_packing_args=sequence_packing_args,
+                rng_state=shuffle_rng_state,
+            )
+
+            source_to_sequence_iterator[dataset_path] = sequence_iterator
+        return source_to_sequence_iterator
+
+    def build_from_rank(
+        self, rank: int, world_size: int
+    ) -> StatefulIterator[Batch, Any]:
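+        # Iterator stack, innermost to outermost: Arrow file reader -> looping ->
+        # preprocessing (tokenization + patching) -> sequence packing -> weighted
+        # source sampling -> batch packing -> multiprocess prefetching.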
+        source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size)
+        weight_rng_state = get_rng_state(self.seed + 1, rank, world_size)
+        sampling_iterator = SamplingIterator(
+            rng_state=weight_rng_state,
+            source_to_weight=self.sources,
+            source_to_iterator=source_to_sequence_iterators,
+        )
+        tokenizer = self.tokenizer_args.build()
+        packing_args = PackingArgs(
+            batch_size=self.batch_size,
+            seq_len=self.seq_len,
+            pad_id=tokenizer.boe_id,
+            max_length=self.max_encoder_seq_length,
+            pad_to_max_length=self.pad_to_max_length,
+            enable_byte_ngrams=self.enable_byte_ngrams,
+        )
+        packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
+        mp_iterator = MultiprocessIterator(
+            packing_iterator, n_batches_to_prefetch=self.prefetch_size
+        )
+
+        return mp_iterator
+
+
+class TrainArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    name: str = "lingua"
+    dump_dir: str = ""
+
+    seed: int = 42
+
+    # Number of gradient accumulation steps
+    # Total batch size is batch_size*grad_acc_steps
+    grad_acc_steps: int = 1
+
+    gc_collect_freq: int = 1000
+    probe_freq: int | None = None
+
+    # Number of optimizer steps to take
+    steps: int = 1000
+
+    data: DataloaderArgs = DataloaderArgs()
+    optim: OptimArgs = OptimArgs()
+    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
+    distributed: DistributedArgs = DistributedArgs()
+    env: EnvironmentArgs = EnvironmentArgs()
+
+    checkpoint: CheckpointArgs = CheckpointArgs()
+    profiling: ProfilerArgs = ProfilerArgs()
+    logging: LoggingArgs = LoggingArgs()
+
+    # If set to None, eval is run locally; otherwise it launches a new job with the given number of GPUs
+    async_eval_gpus: int | None = None
+    eval: Any | None = None
+    eval_on_gpus: int | None = None
+
+    def dump_to_yaml_file(
+        self, path: str, log_config: bool = True, sort_keys: bool = True
+    ):
+        model_dict = self.model_dump(mode="json")
+        yaml_str = yaml.dump(
+            model_dict,
+            allow_unicode=True,
+            sort_keys=sort_keys,
+            default_flow_style=False,
+        )
+        with open(path, "w") as f:
+            if log_config:
+                logger.info("Using the following config for this run:")
+                logger.info(yaml_str)
+            f.write(yaml_str)
diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
new file mode 100644
index 0000000..f494a15
--- /dev/null
+++ b/bytelatent/base_transformer.py
@@ -0,0 +1,585 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from enum import Enum
+from typing import Optional, Tuple, Union
+
+import torch
+from pydantic import BaseModel
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import (
+    BlockMask,
+    _mask_mod_signature,
+    flex_attention,
+)
+from xformers.ops import AttentionBias, fmha
+
+from bytelatent import probe
+
+flex_attention_comp = torch.compile(flex_attention)
+
+
+class InitStdFactor(Enum):
+    DISABLED = "disabled"  # Init std is divided by 1.0
+    GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*n_layers)
+    CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
+    DIM_RATIO = "dim_ratio"  # Init std is divided by model_dim/4096
+
+
+class BaseTransformerArgs(BaseModel):
+    dim: int = 512
+    n_layers: int = 8
+    head_dim: Optional[int] = None
+    n_heads: Optional[int] = None
+    n_kv_heads: Optional[int] = None
+
+    ffn_dim_multiplier: Optional[float] = None
+
+    multiple_of: int = 256
+
+    norm_eps: float = 1e-5
+
+    rope_theta: float = 10000.0
+
+    init_base_std: Optional[float] = None
+    init_std_factor: InitStdFactor = InitStdFactor.DISABLED
+
+    max_seqlen: int = 1024
+
+
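+# Equivalent to F.cross_entropy on the flattened logits/targets, except the
+# log-softmax is explicitly computed in float32 for numerical stability.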
+def cross_entropy(pred, target, **kwargs):
+    return F.nll_loss(
+        F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
+        target.flatten(end_dim=-1),
+        **kwargs,
+    )
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+    assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
+    bs, slen, n_kv_heads, head_dim = x.shape
+    if n_rep == 1:
+        return x
+    return (
+        x[:, :, :, None, :]
+        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+    )
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor using the given dimension 'dim'
+    and the end index 'end'. The 'theta' parameter scales the frequencies.
+    The returned tensor is real-valued and stores, for each (position, frequency)
+    pair, the 2x2 rotation matrix [[cos, -sin], [sin, cos]], giving a shape of
+    (end, dim // 2, 2, 2).
+
+    Args:
+        dim (int): Dimension of the frequency tensor.
+        end (int): End index for precomputing frequencies.
+        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+
+    Returns:
+        torch.Tensor: Precomputed rotation matrices of shape (end, dim // 2, 2, 2).
+    """
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)
+    freqs = torch.outer(t, freqs).float()
+
+    cos, sin = freqs.cos(), freqs.sin()
+
+    return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
+    """
+    Reshape frequency tensor for broadcasting it with another tensor.
+
+    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+    for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+    Args:
+        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+        x (torch.Tensor): Target tensor for broadcasting compatibility.
+        seq_dim (int): Sequence dimension index.
+
+    Returns:
+        torch.Tensor: Reshaped frequency tensor.
+    """
+    ndim = x.ndim
+    assert 0 <= seq_dim < ndim
+    assert freqs_cis.shape == (
+        x.shape[seq_dim],
+        x.shape[-3],
+        2,
+        2,
+    ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
+    shape = [
+        d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
+    ] + [2, 2]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    seq_dim: int,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)  # B S H D -> B S H D/2 1 2
+    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)  # B S H D -> B S H D/2 1 2
+    freqs_cis = reshape_for_broadcast(
+        freqs_cis, xq_, seq_dim
+    ).float()  # S D/2 2 2 -> 1 S 1 D/2 2 2
+    xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
+    xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+def causal_mask(b, h, q_idx, kv_idx):
+    return q_idx >= kv_idx
+
+
+def lengths_to_start_ids(lengths):
+    doc_start = lengths.cumsum(0)
+    doc_start = doc_start.roll(1)
+    doc_start[0] = 0
+    return doc_start
+
+
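+# Example: lengths = [2, 3] gives doc_id = [0, 0, 1, 1, 1] and
+# tok_id = [0, 1, 0, 1, 2], i.e. the document index and within-document
+# position of every token in the stacked sequence.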
+def lengths_to_local_ids(lengths):
+    assert lengths.ndim == 1
+    nb_seqs = lengths.size(0)
+    total_seqlen = lengths.sum()
+    # This gives the document id of each token
+    doc_id = torch.repeat_interleave(lengths)
+    # Compute document start for each document
+    doc_start = lengths_to_start_ids(lengths)
+    # Compute document start for each token
+    doc_start = doc_start[doc_id]
+    # Compute the position of each token within each document
+    tok_id = torch.arange(total_seqlen, device=lengths.device) - doc_start
+
+    return doc_id, tok_id
+
+
+def generate_doc_mask_mod(
+    mask_mod: _mask_mod_signature,
+    lengths: torch.Tensor,
+    kv_lengths: Optional[torch.Tensor] = None,
+) -> _mask_mod_signature:
+    """Generates mask mods that apply to inputs to flex attention in the sequence stacked
+    format.
+
+    Args:
+        mask_mod: The mask mod to apply to the documents
+        lengths: Lengths of each document
+
+    Note:
+        What is the sequence stacked format? When assembling batches of inputs, we
+        take multiple sequences and stack them together to form 1 large sequence. We then
+        use masking to ensure that the attention scores are only applied to tokens within
+        the same document.
+
+    Example:
+
+    - Square mask
+      doc_mask         lengths
+      a a b b b c c    2 3 2
+    a 1 0 0 0 0 0 0
+    a 1 1 0 0 0 0 0
+    b 0 0 1 0 0 0 0
+    b 0 0 1 1 0 0 0
+    b 0 0 1 1 1 0 0
+    c 0 0 0 0 0 1 0
+    c 0 0 0 0 0 1 1
+
+    """
+    kv_lengths = kv_lengths if kv_lengths is not None else lengths
+    q_document_id, q_token_id = lengths_to_local_ids(lengths)
+    kv_document_id, kv_token_id = lengths_to_local_ids(kv_lengths)
+    q_max_idx = lengths.sum() - 1
+    kv_max_idx = kv_lengths.sum() - 1
+
+    def doc_mask_mod(b, h, q_idx, kv_idx):
+        q_idx_cap = torch.minimum(q_max_idx, q_idx)
+        kv_idx_cap = torch.minimum(kv_max_idx, kv_idx)
+        valid_idx = (q_idx <= q_max_idx) & (kv_idx <= kv_max_idx)
+        same_doc = q_document_id[q_idx_cap] == kv_document_id[kv_idx_cap]
+        q_logical = q_token_id[q_idx_cap]
+        kv_logical = kv_token_id[kv_idx_cap]
+        inner_mask = mask_mod(b, h, q_logical, kv_logical)
+        return same_doc & inner_mask & valid_idx
+
+    return doc_mask_mod
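+
+
+# Sketch of intended use (not exercised in this file): wrap the returned
+# mask_mod with torch.nn.attention.flex_attention.create_block_mask to build
+# the BlockMask consumed by flex_attention, e.g.
+#   doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths)
+#   block_mask = create_block_mask(doc_mask_mod, None, None, seqlen, seqlen)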
+
+
+# Rotary embedding as in xformers; check whether the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
+class RotaryEmbedding(torch.nn.Module):
+    """
+    RotaryEmbedding Module
+    """
+
+    def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
+        super().__init__()
+
+        self.theta = theta
+        self.head_dim = head_dim
+        self.max_seqlen = max_seqlen
+
+        self.register_buffer(
+            "freqs_cis",
+            precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
+            persistent=False,
+        )
+
+    def reset_parameters(self):
+        self.freqs_cis[...] = precompute_freqs_cis(
+            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
+        )
+
+    def forward(
+        self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
+    ):
+        """
+        Return freqs_cis corresponding to consecutive seqlen positions or to the provided tok_idx positions
+        Args:
+            seqlen (int): Contiguous sequence length
+            tok_idx (torch.Tensor[int]): Position indices of each token; overrides seqlen when provided
+
+        Returns:
+            torch.Tensor: freqs_cis values for the requested positions
+        """
+        test = (seqlen is not None) or (tok_idx is not None)
+        assert test, "Should provide at least one of seqlen or tok_idx"
+        if tok_idx is not None:
+            return self.freqs_cis[tok_idx]
+        elif seqlen is not None:
+            return self.freqs_cis[0:seqlen]
+
+
+class RMSNorm(nn.Module):
+    """
+    Initialize the RMSNorm normalization layer.
+
+    Args:
+        dim (int): The dimension of the input tensor.
+        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+    Attributes:
+        eps (float): A small value added to the denominator for numerical stability.
+        weight (nn.Parameter): Learnable scaling parameter.
+
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x: torch.Tensor):
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor):
+        x = probe.log_stats(x, "resid")
+        output = self._norm(x.float())
+        return (output * self.weight.float()).type_as(x)
+
+    def reset_parameters(self):
+        torch.nn.init.ones_(self.weight)  # type: ignore
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        head_dim: int,
+        n_heads: int,
+        n_kv_heads: int,
+        rope_theta: float,
+    ):
+        super().__init__()
+
+        self.dim = dim
+        self.head_dim = head_dim
+        self.rope_theta = rope_theta
+
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.heads_per_group = self.n_heads // self.n_kv_heads
+
+        self.wq = nn.Linear(
+            dim,
+            n_heads * head_dim,
+            bias=False,
+        )
+        self.wk = nn.Linear(
+            dim,
+            n_kv_heads * head_dim,
+            bias=False,
+        )
+        self.wv = nn.Linear(
+            dim,
+            n_kv_heads * head_dim,
+            bias=False,
+        )
+
+        self.wo = nn.Linear(
+            n_heads * head_dim,
+            dim,
+            bias=False,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freq_cis: torch.Tensor,
+        tok_idx: Optional[torch.Tensor] = None,
+        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
+        attn_impl: str = "sdpa",
+    ) -> torch.Tensor:
+        # B S D
+        bsz, seq_len, dim = x.shape
+        xq = self.wq(x.view_as(x))
+        xk = self.wk(x.view_as(x))
+        xv = self.wv(x.view_as(x))
+
+        output_shape = xq.shape
+        # B S D -> B S H D
+        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
+        xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+
+        xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
+
+        # This condition helps us be easily compatible
+        # with inference by adding a pluggable KVCache
+        if hasattr(self, "kv_cache"):
+            xk, xv = self.kv_cache.update(xk, xv, tok_idx)
+
+        xk = repeat_kv(xk, self.heads_per_group, dim=2)
+        xv = repeat_kv(xv, self.heads_per_group, dim=2)
+
+        if attn_impl == "flex_attention":
+            assert mask is None or isinstance(mask, BlockMask)
+            xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
+            output = flex_attention_comp(xq, xk, xv, block_mask=mask)
+            output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D
+
+        elif attn_impl == "fmha":
+            assert mask is None or isinstance(mask, AttentionBias)
+            output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
+            # fmha uses the B S H D layout instead of PyTorch's B H S D
+
+        elif attn_impl == "sdpa":
+            xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
+            assert mask is None or isinstance(mask, (str, torch.Tensor))
+            is_causal = (mask == "causal") if isinstance(mask, str) else False
+            mask = mask if isinstance(mask, torch.Tensor) else None
+            output = F.scaled_dot_product_attention(
+                xq,
+                xk,
+                xv,
+                is_causal=is_causal,
+                attn_mask=mask,
+            )
+            output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D
+        else:
+            raise NotImplementedError(
+                f"Attention implementation {attn_impl} not supported"
+            )
+
+        output = self.wo(output.reshape(output_shape))
+
+        return output
+
+    def reset_parameters(self, init_std=None, factor=1.0):
+        init_std = init_std or (self.dim ** (-0.5))
+
+        for w in [self.wq, self.wk, self.wv]:
+            nn.init.trunc_normal_(
+                w.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
+
+        nn.init.trunc_normal_(
+            self.wo.weight,
+            mean=0.0,
+            std=init_std / factor,
+            a=-3 * init_std,
+            b=3 * init_std,
+        )
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+        mp_size: int = 1,
+    ):
+        super().__init__()
+
+        hidden_dim = int(2 * hidden_dim / 3)
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        assert hidden_dim % mp_size == 0
+
+        self.dim = dim
+        self.hidden_dim = hidden_dim
+
+        self.w1 = nn.Linear(
+            dim,
+            hidden_dim,
+            bias=False,
+        )
+        self.w3 = nn.Linear(
+            dim,
+            hidden_dim,
+            bias=False,
+        )
+        self.w2 = nn.Linear(
+            hidden_dim,
+            dim,
+            bias=False,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # B S D
+        x1 = self.w1(x.view_as(x))
+        x3 = self.w3(x.view_as(x))
+        output = self.w2(F.silu(x1) * x3)
+        return output
+
+    def reset_parameters(self, init_std=None, factor=1.0):
+        in_init_std = init_std or (self.dim ** (-0.5))
+        out_init_std = init_std or (self.hidden_dim ** (-0.5))
+        in_init_std = in_init_std
+        out_init_std = out_init_std / factor
+        for w in [self.w1, self.w3]:
+            nn.init.trunc_normal_(
+                w.weight,
+                mean=0.0,
+                std=in_init_std,
+                a=-3 * in_init_std,
+                b=3 * in_init_std,
+            )
+        nn.init.trunc_normal_(
+            self.w2.weight,
+            mean=0.0,
+            std=out_init_std,
+            a=-3 * out_init_std,
+            b=3 * out_init_std,
+        )
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: BaseTransformerArgs):
+        super().__init__()
+
+        assert (args.head_dim is not None) or (
+            args.n_heads is not None
+        ), "Should specify at least head_dim or n_heads"
+        self.head_dim = args.head_dim or args.dim // args.n_heads
+        self.n_heads = args.n_heads or args.dim // args.head_dim
+        self.n_kv_heads = args.n_kv_heads or self.n_heads
+
+        assert args.n_heads % self.n_kv_heads == 0
+        assert args.dim % args.n_heads == 0
+
+        self.attention = Attention(
+            dim=args.dim,
+            head_dim=self.head_dim,
+            n_heads=self.n_heads,
+            n_kv_heads=self.n_kv_heads,
+            rope_theta=args.rope_theta,
+        )
+        self.feed_forward = FeedForward(
+            dim=args.dim,
+            hidden_dim=4 * args.dim,
+            multiple_of=args.multiple_of,
+            ffn_dim_multiplier=args.ffn_dim_multiplier,
+        )
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freq_cis: torch.Tensor,
+        tok_idx: Optional[torch.Tensor] = None,
+        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
+        attn_impl: str = "sdpa",
+    ) -> torch.Tensor:
+        h = x + self.attention(
+            self.attention_norm(x),
+            freq_cis,
+            tok_idx=tok_idx,
+            mask=mask,
+            attn_impl=attn_impl,
+        )
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+    def init_weights(self, init_std=None, factor=1.0):
+        self.attention.reset_parameters(init_std, factor)
+        self.attention_norm.reset_parameters()
+
+        self.feed_forward.reset_parameters(init_std, factor)
+        self.ffn_norm.reset_parameters()
+
+
+class BaseTransformer(nn.Module):
+    def __init__(self, args: BaseTransformerArgs):
+        super().__init__()
+        self.dim = args.dim
+        self.init_base_std = args.init_base_std
+        self.init_std_factor = InitStdFactor(args.init_std_factor)
+        self.max_seqlen = args.max_seqlen
+        self.rope_embeddings = RotaryEmbedding(
+            theta=args.rope_theta,
+            head_dim=args.head_dim or args.dim // args.n_heads,
+            max_seqlen=args.max_seqlen,
+        )
+
+        self.layers = nn.ModuleList()
+        for _ in range(args.n_layers):
+            self.layers.append(TransformerBlock(args))
+
+    def forward(
+        self,
+        h,
+        tok_idx: Optional[torch.Tensor] = None,
+        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
+        attn_impl: str = "sdpa",
+    ):
+
+        freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
+
+        for i, layer in enumerate(self.layers):
+            h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
+        return h
+
+    def reset_parameters(self):
+        # Either use fixed base std or sqrt model dim
+        self.rope_embeddings.reset_parameters()
+
+    def init_weights(self):
+        self.reset_parameters()
+        for depth, layer in enumerate(self.layers):
+            factor = {
+                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
+                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
+                InitStdFactor.DIM_RATIO: self.dim / 4096,
+                InitStdFactor.DISABLED: 1.0,
+            }[self.init_std_factor]
+
+            layer.init_weights(self.init_base_std, factor)
diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py
new file mode 100644
index 0000000..bcf591e
--- /dev/null
+++ b/bytelatent/checkpoint.py
@@ -0,0 +1,311 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import json
+import logging
+import os
+import re
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dcp
+import torch.nn as nn
+import torch.optim.optimizer
+from pydantic import BaseModel, ConfigDict
+from torch.distributed._tensor import DeviceMesh
+from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
+from torch.distributed.checkpoint.state_dict import (
+    get_model_state_dict,
+    get_state_dict,
+    set_state_dict,
+)
+
+from bytelatent.distributed import get_is_master
+
+logger = logging.getLogger("CHECKPOINT")
+
+FOLDER_NAME = "{:010d}"
+RE_FOLDER = r"\d{10}"
+
+RE_CKPT = r"__\d_\d\.distcp"
+
+CONSOLIDATE_FOLDER = "consolidated"
+CONSOLIDATE_NAME = "consolidated.pth"
+
+CONFIG_NAME = "params.json"
+TRAIN_STATE_NAME = "train_state_{:05d}.json"
+RE_DIGITS = re.compile(r"\d+")
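+# Example layout: step 500 is saved under "0000000500/", which contains the
+# distcp shards, params.json, and one train_state_XXXXX.json per dp rank
+# (e.g. train_state_00000.json for dp rank 0).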
+
+
+class SaveEvery(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    every: int = 1000
+    keep: int = 0
+
+
+class CheckpointArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    dump: SaveEvery = SaveEvery()
+    eval: SaveEvery = SaveEvery()
+    path: str | None = None
+    init_ckpt_path: str | None = None
+    continue_training_from_init: bool = False
+
+
+def _get_key_step(name: str):
+    return int(re.findall(RE_DIGITS, name)[-1])
+
+
+def consolidate_checkpoints(ckpt_dir: str):
+    """
+    Consolidates all FSDP checkpoints in a directory into a single file.
+    The consolidated checkpoint is saved in a subdirectory of ckpt_dir.
+
+    Parameters:
+        ckpt_dir: str - path to the directory containing the checkpoints
+
+    Returns the path to the consolidated checkpoint
+    """
+    consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
+    if not (consolidate_path / CONSOLIDATE_NAME).exists():
+        consolidate_path.mkdir(exist_ok=True)
+        logger.info(f"Consolidating to: {str(consolidate_path)}")
+        dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
+        (consolidate_path / CONFIG_NAME).write_text(
+            (Path(ckpt_dir) / CONFIG_NAME).read_text()
+        )
+        logger.info("Consolidated !")
+    return consolidate_path
+
+
+def load_from_checkpoint(
+    ckpt_dir: str,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    model_key: str = "model",
+    optim_key: str = "optim",
+):
+    if not (Path(ckpt_dir) / ".metadata").exists():
+        raise ValueError(
+            "Please convert the checkpoint to the distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
+        )
+
+    state_dict = {}
+    if optimizer is not None:
+        state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer)
+    else:
+        state_dict[model_key] = get_model_state_dict(model)
+        if model_key == "":  # If only loading a model directly, the key should be empty
+            state_dict = state_dict.pop(model_key)
+
+    dcp.load(state_dict, checkpoint_id=ckpt_dir)
+
+
+class CheckpointManager:
+    def __init__(self, args: CheckpointArgs):
+        self.path = args.path
+        self.dump_every = args.dump
+        self.eval_every = args.eval
+        self.init_ckpt_path = args.init_ckpt_path
+        self.continue_training_from_init = args.continue_training_from_init
+
+        assert os.path.exists(
+            self.path
+        ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
+
+        self.existing_saves = self.get_existing_saves()
+
+    def get_existing_saves(self) -> List[Path]:
+        folders = [
+            p
+            for p in Path(self.path).iterdir()
+            if p.is_dir() and re.match(RE_FOLDER, p.name)
+        ]
+        folders.sort(key=lambda p: _get_key_step(p.name))
+        return folders
+
+    def clean_up(self):
+        logger.info("Cleaning up checkpoints...")
+        dump_folders = []
+        eval_folders = []
+        other_folders = []
+        for p in self.existing_saves:
+            is_dump = _get_key_step(p.name) % self.dump_every.every == 0
+            is_eval = _get_key_step(p.name) % self.eval_every.every == 0
+            if is_dump:
+                dump_folders.append(p)
+            if is_eval:
+                eval_folders.append(p)
+            if not (is_dump or is_eval):
+                other_folders.append(p)
+
+        logger.info(f"Dump folders: {dump_folders}")
+        logger.info(f"Eval folders: {eval_folders}")
+        logger.info(f"Other folders: {other_folders}")
+
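+        # Keep only the most recent `keep` dump/eval checkpoints; keep <= 0
+        # means never delete checkpoints of that kind.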
+        if self.dump_every.keep > 0:
+            dump_folders = dump_folders[-self.dump_every.keep :]
+        if self.eval_every.keep > 0:
+            eval_folders = eval_folders[-self.eval_every.keep :]
+
+        folder_to_keep = set(other_folders + dump_folders + eval_folders)
+        folder_to_remove = set(self.existing_saves) - folder_to_keep
+
+        logger.info(f"Removing folders: {folder_to_remove}")
+
+        if dist.get_rank() == 0:
+            for folder in folder_to_remove:
+                for file in folder.iterdir():
+                    if file.is_file():
+                        file.unlink()
+                    elif file.is_dir():
+                        assert file.name in [CONSOLIDATE_FOLDER]
+                        for f in file.iterdir():
+                            f.unlink()
+                        file.rmdir()
+                folder.rmdir()
+
+        dist.barrier()
+
+        self.existing_saves = list(folder_to_keep)
+        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
+
+    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
+        path = None
+        for p in reversed(self.existing_saves):
+            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
+                path = p
+                break
+        return path
+
+    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
+        folder = base_path / folder_name
+        if get_is_master():
+            folder.mkdir(parents=False, exist_ok=True)
+        if dist.is_initialized():
+            dist.barrier()
+        return folder
+
+    def _get_dp_tp_mesh(
+        self, device_mesh: Optional[DeviceMesh] = None
+    ) -> Tuple[int, int]:
+        dp_rank = 0
+        tp_rank = 0
+        if device_mesh is not None:
+            if "dp_replicate" in device_mesh.mesh_dim_names:
+                dp_rank = device_mesh.get_local_rank("dp_replicate")
+                if "dp_shard" in device_mesh.mesh_dim_names:
+                    dp_rank = dp_rank * device_mesh[
+                        "dp_replicate"
+                    ].size() + device_mesh.get_local_rank("dp_shard")
+            if "tp" in device_mesh.mesh_dim_names:
+                tp_rank = device_mesh.get_local_rank("tp")
+        return dp_rank, tp_rank
+
+    @torch.no_grad()
+    def get_state_dict(
+        self,
+        model,
+        optimizer,
+    ):
+        model_sd, optim_sd = get_state_dict(model, optimizer)
+        return {"model": model_sd, "optim": optim_sd}
+
+    def save(
+        self,
+        model,
+        optimizer,
+        train_state,
+        config,
+        device_mesh: Optional[DeviceMesh] = None,
+    ) -> bool:
+
+        # When creating the directory, check whether only rank 0 should do it or whether there is a better solution
+        path = Path(self.path)
+        curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
+        logger.info(f"Saving to: {str(curr_save_dir)}")
+
+        if dist.is_initialized():
+            dist.barrier()
+
+        logger.info("Saving...")
+        state_dict = self.get_state_dict(model, optimizer)
+        dcp.save(state_dict, checkpoint_id=curr_save_dir)
+        logger.info("State dict saved!")
+
+        if dist.is_initialized():
+            dist.barrier()
+
+        if get_is_master():
+            config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
+
+        # Add json dump here
+        dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
+        if tp_rank == 0:
+            train_state_name = TRAIN_STATE_NAME.format(dp_rank)
+            logger.info(
+                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
+            )
+            with open(curr_save_dir / train_state_name, "w") as f:
+                json.dump(train_state.state_dict(), f)
+            logger.info("Train state saved !")
+
+        self.existing_saves.append(curr_save_dir)
+
+        self.clean_up()
+
+        if dist.is_initialized():
+            dist.barrier()
+        return True
+
+    @torch.no_grad()
+    def load(
+        self,
+        model: nn.Module,
+        optimizer,
+        train_state,
+        device_mesh: DeviceMesh,
+        path: Optional[Path] = None,
+    ):
+        dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
+        # Loading tries the provided path first; if not available, the last saved step, and finally the init path
+        path = path or self.get_last_step_path(dp_rank=dp_rank)
+        # If none of those are available don't do anything
+        if path is None:
+            # If no checkpoints exist do nothing
+            return
+
+        # Only load train state if it's provided, the files exist and we're not loading from init path
+        train_state_name = TRAIN_STATE_NAME.format(dp_rank)
+        logger.info("Reloading train state")
+        with open(path / train_state_name, "r") as f:
+            train_state_dict = json.load(f)
+        train_state.load_state_dict(train_state_dict)
+        logger.info("Train state reloaded")
+
+        logger.info(f"Loading from: {str(path)}")
+        state_dict = self.get_state_dict(
+            model=model,
+            optimizer=optimizer,
+        )
+        dcp.load(state_dict, checkpoint_id=path)
+        logger.info("State dict loaded.")
+
+        logger.info("Reloading model and optim")
+
+        set_state_dict(
+            model,
+            optimizer,
+            model_state_dict=state_dict["model"],
+            optim_state_dict=state_dict["optim"],
+        )
+        logger.info("Model and optim reloaded")
+
+    @classmethod
+    def instantiate_and_make_dir(cls, args: CheckpointArgs):
+        if get_is_master():
+            os.makedirs(args.path, exist_ok=True)
+        dist.barrier()
+
+        return cls(args)
diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml
new file mode 100644
index 0000000..5f6debb
--- /dev/null
+++ b/bytelatent/configs/debug.yaml
@@ -0,0 +1,110 @@
+# Template config; you need to change dump_dir, data.root_dir, and tokenizer.path
+# Evals can be activated by uncommenting their config
+# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
+
+dump_dir: /tmp/
+name: "debug"
+steps: 100_000
+probe_freq: null
+seed: 777
+optim:
+  lr: 4e-04
+  warmup: 500
+  lr_min_ratio: 0.1
+  clip: 10.0
+
+distributed:
+  fsdp_type: full_shard
+  compile: true
+  model_dtype: bf16
+  matmul_allow_tf32: false
+  selective_activation_checkpointing: false
+  tp_size: 1
+
+model:
+  n_heads: 8
+  dim: 512
+  vocab_size: 260
+  dim_token: 256
+  patch_size: 6
+  tokenization_mode: "bytes"
+  patching_mode: "space"
+  tie_local_encoder_decoder_logits: false
+  data_loader_patching: true
+  max_encoder_seq_length: 12288
+  pad_to_max_length: true
+  patching_threshold: 3.1439168453216553
+  encoder_hash_byte_group_size: [4]
+  encoder_hash_byte_group_vocab: 50002
+  encoder_hash_byte_group_nb_functions: 3
+  encoder_enable_byte_ngrams: false
+  cross_attn_encoder: true # assuming cross_attention is true
+  cross_attn_decoder: true # assuming cross_attention is true
+  cross_attn_window_encoder: 512
+  cross_attn_window_decoder: 512
+  dim_local_encoder: 256
+  dim_local_decoder: 256
+  cross_attn_k: 8
+  cross_attn_nheads: 4
+  cross_attn_all_layers_decoder: true
+  cross_attn_all_layers_encoder: true
+  cross_attn_use_flex_attention: true
+  cross_attn_init_by_pooling: true
+  log_patch_lengths: true
+  non_linearity: "swiglu"
+  use_rope: true
+  recompute_fc1_out: false
+  recompute_fc3_out: false
+  recompute_attn: false
+  custom_bwd: false
+  layer_ckpt: "none"
+  efficient_attn: "sdpa"
+  patch_only_encoder: false
+  patch_only_decoder: false
+  use_local_encoder_transformer: true
+  init_use_gaussian: true
+  init_use_depth: "current"
+  attn_bias_type: "block_causal"
+  alpha_depth: "disabled"
+  max_length: 256
+  local_attention_window_len: 512
+  max_seqlen: 12288
+  downsampling_by_pooling: "max"
+
+data:
+  root_dir: ???
+  sources:
+    dclm_baseline_1.0: 1.0
+  batch_size: 2
+  prefetch_size: 64
+  seq_len: 4096
+  load_async: true
+  preprocess_dir: ???
+  tokenizer_args:
+    name: blt
+    init_kwargs:
+      bpe_tokenizer_path: ???
+
+profiling:
+  run: false
+
+checkpoint:
+  dump:
+    every: 500
+    keep: 3
+  eval:
+    every: 1000
+    keep: -1
+
+logging:
+  freq: 10
+
+eval_on_gpus: 8
+eval:
+  dataset_dir: /checkpoint/amaia/codegen/datasets/eval
+  tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
+  generator:
+    max_tokens: 65536
+    dtype: bf16
+
+  mp_size: 1
diff --git a/bytelatent/constants.py b/bytelatent/constants.py
new file mode 100644
index 0000000..341c7ff
--- /dev/null
+++ b/bytelatent/constants.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import os
+from pathlib import Path
+
+BLT_DATA = Path(os.environ.get("BLT_DATA", "data"))
diff --git a/bytelatent/data/__init__.py b/bytelatent/data/__init__.py
new file mode 100644
index 0000000..71ca4b1
--- /dev/null
+++ b/bytelatent/data/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py
new file mode 100644
index 0000000..7e142e4
--- /dev/null
+++ b/bytelatent/data/data_types.py
@@ -0,0 +1,115 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import json
+from dataclasses import dataclass
+from typing import Any, Iterator
+
+import numpy as np
+from pydantic import BaseModel, ConfigDict
+
+
+class BltExample(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    sample_id: str
+    text: str
+    tokens: list[int] | None
+    entropies: list[float] | None
+    patch_lengths: list[int] | None
+    mask: list[bool] | None
+
+
+class MultiChoiceState(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    root_dir: str
+    sources: dict[str, float]
+    source_to_state: dict[str, Any]
+    rng_state: dict[str, Any]
+
+
+class PrefetchState(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    seq_idx: int
+    rng_state: dict[str, Any]
+    prefetch_size: int
+    batch_size: int
+
+
+class BltPackTokensState(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    start_token: int
+    output_seq_len: int
+    n_views: int = 2
+
+
+class DataLoaderState(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    multi_choice_state: MultiChoiceState
+    pack_tokens_state: BltPackTokensState
+    prefetch_state: PrefetchState
+
+
+BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
+
+
+class BltSequence(BaseModel):
+    tokens: list[int]
+    mask: list[bool]
+    patch_lengths: list[int]
+
+
+@dataclass
+class Batch:
+    x: np.ndarray
+    y: np.ndarray
+    mask: np.ndarray | None = None
+    patch_lengths: np.ndarray | None = None
+    ngram_ids: np.ndarray | None = None
+    is_final: bool = False
+
+    def to_python_dict(self) -> dict:
+        x = self.x.tolist()
+        y = self.y.tolist()
+        if self.mask is None:
+            mask = None
+        else:
+            mask = self.mask.tolist()
+        if self.patch_lengths is None:
+            patch_lengths = None
+        else:
+            patch_lengths = self.patch_lengths.tolist()
+        if self.ngram_ids is None:
+            ngram_ids = None
+        else:
+            ngram_ids = self.ngram_ids.tolist()
+        return {
+            "x": x,
+            "y": y,
+            "mask": mask,
+            "patch_lengths": patch_lengths,
+            "ngram_ids": ngram_ids,
+            "is_final": self.is_final,
+        }
+
+    @classmethod
+    def from_python_dict(cls, data: dict) -> "Batch":
+        x = np.array(data["x"])
+        y = np.array(data["y"])
+        if data["mask"] is None:
+            mask = None
+        else:
+            mask = np.array(data["mask"])
+        if data["patch_lengths"] is None:
+            patch_lengths = None
+        else:
+            patch_lengths = np.array(data["patch_lengths"])
+        if data["ngram_ids"] is None:
+            ngram_ids = None
+        else:
+            ngram_ids = np.array(data["ngram_ids"])
+        return Batch(
+            x=x,
+            y=y,
+            mask=mask,
+            patch_lengths=patch_lengths,
+            ngram_ids=ngram_ids,
+            is_final=data["is_final"],
+        )
diff --git a/bytelatent/data/iterators/__init__.py b/bytelatent/data/iterators/__init__.py
new file mode 100644
index 0000000..71ca4b1
--- /dev/null
+++ b/bytelatent/data/iterators/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py
new file mode 100644
index 0000000..7fb442b
--- /dev/null
+++ b/bytelatent/data/iterators/abstract_iterator.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import abc
+from typing import Any, Generator, Generic, TypeVar
+
+T = TypeVar("T")
+C = TypeVar("C")
+
+
+class StatefulIterator(Generic[T, C], abc.ABC):
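+    """
+    Contract (a brief descriptive note, not enforced beyond the abstract methods):
+    create_iter() yields items of type T from the current position, and get_state()
+    returns a serializable state of type C whose build() reconstructs an equivalent
+    iterator, so that data loading can be paused and resumed.
+    """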
+
+    @abc.abstractmethod
+    def get_state(self) -> C:
+        pass
+
+    @abc.abstractmethod
+    def create_iter(self) -> Generator[T, Any, None]:
+        pass
+
+
+class IteratorState(Generic[C]):
+    @abc.abstractmethod
+    def build(self) -> StatefulIterator[T, C]:
+        pass
diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py
new file mode 100644
index 0000000..df5f023
--- /dev/null
+++ b/bytelatent/data/iterators/arrow_iterator.py
@@ -0,0 +1,216 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import re
+from logging import getLogger
+from pathlib import Path
+from typing import Any, Generator
+
+import pyarrow as pa
+
+# pyarrow needs the initialization from this import
+import pyarrow.dataset  # pyright: ignore
+from pydantic import BaseModel, ConfigDict
+
+from bytelatent import ByteLatentError
+from bytelatent.data.data_types import BltExample
+from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+
+logger = getLogger(__name__)
+
+
+class ArrowFileIteratorState(BaseModel, IteratorState):
+    model_config = ConfigDict(extra="forbid")
+    file_path: str | None
+    row_num: int
+    num_workers: int
+    worker_id: int
+    preprocess_dir: str | None
+    dataset_files: list[str] | None
+    entropy_model_name: str | None
+    arrow_batch_size: int = 100
+
+    def build(self) -> "ArrowFileIterator":
+        arrow_file = ArrowFileIterator(
+            file_path=self.file_path,
+            worker_id=self.worker_id,
+            num_workers=self.num_workers,
+            preprocess_dir=self.preprocess_dir,
+            entropy_model_name=self.entropy_model_name,
+            arrow_batch_size=self.arrow_batch_size,
+            dataset_files=self.dataset_files,
+        )
+        if self.row_num != 0:
+            arrow_file._set_row_num(self.row_num)
+        return arrow_file
+
+
+def shard_sort_key(file: str | Path):
+    match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
+    shard_number = int(match.group(1))
+    return shard_number
+
+
+class ArrowFileIterator(StatefulIterator):
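+    """
+    Iterates over BltExamples stored in Arrow shard files. When dataset_files is not
+    given, shards are discovered (a descriptive note summarizing the code below) under
+    preprocess_dir/<dataset>/<entropy_model_name> as <jsonl_name>.shard_*.arrow, each
+    requiring a matching .complete marker file. Rows are striped across workers: worker
+    `worker_id` yields the rows whose zero-based index i satisfies i % num_workers == worker_id.
+    """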
+    def __init__(
+        self,
+        *,
+        file_path: str | None,
+        worker_id: int,
+        num_workers: int,
+        preprocess_dir: str | None,
+        entropy_model_name: str | None,
+        arrow_batch_size: int,
+        dataset_files: list[str] | None = None,
+    ):
+        assert 0 <= worker_id < num_workers, (worker_id, num_workers)
+        if file_path is None and dataset_files is None:
+            raise ByteLatentError("file_path and dataset_files cannot both be None")
+        self.row_num = 0
+        self.iter_id = 0
+        self.batch_iterator = None
+        self.batch_to_consume = None
+        self.dataset = None
+        self.file_path = file_path
+        self.worker_id = worker_id
+        self.num_workers = num_workers
+        self.preprocess_dir = preprocess_dir
+        self.entropy_model_name = entropy_model_name
+        self.arrow_batch_size = arrow_batch_size
+        if dataset_files is None:
+            # Prepare arrow shards
+            jsonl_file = Path(file_path)
+            parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
+            assert parts is not None
+            dataset = parts.group(1)
+            data_dir = Path(preprocess_dir) / dataset / entropy_model_name
+            shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
+            for s in shard_files:
+                if not (data_dir / f"{s.name}.complete").exists():
+                    raise ValueError(f"Missing .complete for input file: {s}")
+
+            shard_files = sorted(shard_files, key=shard_sort_key)
+            if len(shard_files) == 0:
+                raise ByteLatentError(
+                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
+                )
+            self.dataset_files = [str(f) for f in shard_files]
+        else:
+            self.preprocess_dir = None
+            self.dataset_files = dataset_files
+
+    def get_state(self) -> ArrowFileIteratorState:
+        return ArrowFileIteratorState(
+            file_path=self.file_path,
+            row_num=self.row_num,
+            worker_id=self.worker_id,
+            num_workers=self.num_workers,
+            preprocess_dir=self.preprocess_dir,
+            entropy_model_name=self.entropy_model_name,
+            arrow_batch_size=self.arrow_batch_size,
+            dataset_files=self.dataset_files,
+        )
+
+    def create_iter(
+        self,
+    ) -> Generator[BltExample, Any, None]:
+        if self.dataset is None:
+            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
+            self.batch_iterator = self.dataset.to_batches(
+                batch_size=self.arrow_batch_size
+            )
+        self.iter_id += 1
+        if self.batch_to_consume is not None:
+            batch_columns: dict[str, list] = self.batch_to_consume
+            self.batch_to_consume = None
+            sample_ids = batch_columns["sample_id"]
+            texts = batch_columns["text"]
+            entropies = batch_columns["entropies"]
+            for i in range(len(sample_ids)):
+                out = BltExample(
+                    sample_id=sample_ids[i],
+                    entropies=entropies[i],
+                    text=texts[i],
+                    tokens=None,
+                    mask=None,
+                    patch_lengths=None,
+                )
+                self.row_num += 1
+                if (self.row_num - 1) % self.num_workers == self.worker_id:
+                    yield out
+
+        for batch in self.batch_iterator:
+            batch_columns = batch.to_pydict()
+            sample_ids = batch_columns["sample_id"]
+            texts = batch_columns["text"]
+            entropies = batch_columns["entropies"]
+            for i in range(len(sample_ids)):
+                out = BltExample(
+                    sample_id=sample_ids[i],
+                    entropies=entropies[i],
+                    text=texts[i],
+                    tokens=None,
+                    mask=None,
+                    patch_lengths=None,
+                )
+                self.row_num += 1
+                if (self.row_num - 1) % self.num_workers == self.worker_id:
+                    yield out
+
+    def _set_row_num(self, target_row_num: int):
+        logger.info(
+            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
+        )
+        if target_row_num is None or target_row_num == 0:
+            self.row_num = 0
+            self.dataset = None
+            self.batch_iterator = None
+            self.batch_to_consume = None
+        else:
+            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
+            self.batch_iterator = self.dataset.to_batches(
+                batch_size=self.arrow_batch_size
+            )
+            curr_remaining = target_row_num
+            for batch in self.batch_iterator:
+                if len(batch) > curr_remaining:
+                    batch_columns: dict[str, list] = batch.to_pydict()
+                    batch_columns["sample_id"] = batch_columns["sample_id"][
+                        curr_remaining:
+                    ]
+                    batch_columns["entropies"] = batch_columns["entropies"][
+                        curr_remaining:
+                    ]
+                    batch_columns["text"] = batch_columns["text"][curr_remaining:]
+                    self.batch_to_consume = batch_columns
+                    break
+                elif len(batch) == curr_remaining:
+                    # We are exactly at the end of the batch,
+                    # so the next batch is the right spot
+                    break
+                else:
+                    curr_remaining -= len(batch)
+            self.row_num = target_row_num
+        logger.info(
+            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
+        )
+
+
+TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
+
+
+def find_and_sanitize_chunks(
+    dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
+):
+    dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
+    n_chunks = len(dataset_chunks)
+
+    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
+
+    if n_chunks > world_size:
+        n_discard = n_chunks - world_size
+        dataset_chunks = dataset_chunks[:world_size]
+    else:
+        assert (
+            world_size % n_chunks == 0
+        ), "World size should be a multiple of number of chunks"
+
+    return dataset_chunks
diff --git a/bytelatent/data/iterators/looping_iterator.py b/bytelatent/data/iterators/looping_iterator.py
new file mode 100644
index 0000000..2eff38c
--- /dev/null
+++ b/bytelatent/data/iterators/looping_iterator.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+from pydantic import BaseModel
+
+from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.arrow_iterator import (
+    ArrowFileIterator,
+    ArrowFileIteratorState,
+)
+
+
+class LoopingIteratorState(BaseModel, IteratorState):
+    file_iterator_state: ArrowFileIteratorState
+    epoch: int
+
+    def build(self) -> "LoopingIterator":
+        return LoopingIterator(
+            file_iterator=self.file_iterator_state.build(),
+            epoch=self.epoch,
+        )
+
+
+class LoopingIterator(StatefulIterator):
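+    """
+    Endlessly replays the underlying ArrowFileIterator, incrementing `epoch` at the
+    start of each pass over the file.
+    """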
+    def __init__(self, file_iterator: ArrowFileIterator, epoch: int = -1):
+        self.file_iterator = file_iterator
+        self.epoch = epoch
+
+    def get_state(self):
+        return LoopingIteratorState(
+            file_iterator_state=self.file_iterator.get_state(), epoch=self.epoch
+        )
+
+    def create_iter(self):
+        while True:
+            self.epoch += 1
+            iterator = self.file_iterator.create_iter()
+            yield from iterator
diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py
new file mode 100644
index 0000000..f17ca6e
--- /dev/null
+++ b/bytelatent/data/iterators/multiprocess_iterator.py
@@ -0,0 +1,243 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import json
+import logging
+import multiprocessing as mp
+from multiprocessing.synchronize import Event as EventClass
+from queue import Empty, Full
+
+import numpy as np
+from pydantic import BaseModel, ConfigDict
+
+from bytelatent.data.data_types import Batch
+from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.packing_iterator import PackingIteratorState
+
+logger = logging.getLogger()
+
+
+class MultiprocessIteratorState(BaseModel, IteratorState):
+    model_config = ConfigDict(extra="forbid")
+    base_iterator_state: PackingIteratorState
+    n_batches_to_prefetch: int
+    serialized_prefetch_buffer: str
+
+    def build(self):
+        base_iterator = self.base_iterator_state.build()
+        data = json.loads(self.serialized_prefetch_buffer)
+        prefetch_buffer = [Batch.from_python_dict(item) for item in data]
+        return MultiprocessIterator(
+            base_iterator,
+            n_batches_to_prefetch=self.n_batches_to_prefetch,
+            prefetch_buffer=prefetch_buffer,
+        )
+
+
+def start_work_from_state(
+    batch_queue: mp.Queue,
+    state_queue: mp.Queue,
+    stop_event: EventClass,
+    state_dumped_event: EventClass,
+    state: IteratorState,
+):
+    logging.info("Worker thread: Starting base_iterator work")
+    stateful_iterator = state.build()
+    iterator = stateful_iterator.create_iter()
+    for item in iterator:
+        while not stop_event.is_set():
+            try:
+                # Attempt to put the item on the queue; on timeout, try again (the main thread may be busy)
+                batch_queue.put(item, timeout=0.1)
+                # On success, stop trying
+                break
+            except Full:
+                pass
+        if stop_event.is_set():
+            # Signal the end of output. This ensures that even if the queue takes a while to
+            # buffer, the main thread receives everything (and tosses this fake batch).
+            logging.info(
+                "Worker thread: Stop event detected, outputting is_final=True batch"
+            )
+            batch_queue.put(
+                Batch(
+                    x=np.zeros((1, 1)),
+                    y=np.zeros((1, 1)),
+                    is_final=True,
+                    mask=None,
+                    patch_lengths=None,
+                    ngram_ids=None,
+                )
+            )
+            break
+
+    try:
+        logging.info("Worker thread: outputting state")
+        state_queue.put(iterator.get_state(), timeout=1)
+        logging.info("Worker thread: state dump complete")
+        state_dumped_event.set()
+        logging.info("Worker thread: set state_dump_event")
+    except Full:
+        raise ValueError(
+            "Attempted to dump state into the state queue, but it was full"
+        )
+
+
+class MultiprocessIterator(StatefulIterator):
+    """
+    Design sketch of the multiprocess iterator:
+
+    Given the base_iterator, the only thing we do with this is call get_state()
+    so that we can pass that through to the background worker process.
+
+    The background process will receive this, rebuild the iterator, then start yielding from it.
+
+    However, in order to implement MultiprocessIterator.get_state(), we need to be able to accurately get
+    (1) the state of the iterator in the worker process
+    (2) the currently buffered items in the Queue
+
+    To do this, we use:
+    - batch_queue: This is the prefetch buffer the worker yields to and the main loop yields from
+    - state_queue: This size 1 queue will be how the worker sends the iterator state once it has halted iterating.
+        It must hold the state in addition to the last batch, if the queue was full at the time the stop event is sent.
+    - stop_iterating_event: Once this is issued from the main loop, the worker will stop iterating and enter cleanup.
+        During cleanup, the iterator will send the state of the current iterator to the main loop,
+        in addition to possibly the last batch if the batch_queue was full at the time
+    - state_dumped_event: When the main loop issues the stop_iterating_event, it will wait until the state_dumped_event to attempt
+        to get state from the state_queue. It must do this since the worker may take some time to create and send the state.
+        Once received by the main loop, the main loop can safely store the Queue (plus maybe the last batch) as the prefetch buffer,
+        get the worker iterator's state, and terminate the background process + delete associated objects.
+
+    At this point, calling create_iter() again will bootstrap everything from the stored state, while the old
+    python iterator will throw an error if used, since it will not iterate anymore.
+
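+    A minimal usage sketch (the names `packing_state`, `should_checkpoint`, and the training
+    loop around it are hypothetical, not part of this module):
+
+        state = MultiprocessIteratorState(
+            base_iterator_state=packing_state,
+            n_batches_to_prefetch=4,
+            serialized_prefetch_buffer="[]",
+        )
+        mp_iterator = state.build()
+        for batch in mp_iterator.create_iter():
+            ...  # training step
+            if should_checkpoint:
+                # Halts the worker, drains the queue, and returns a resumable state
+                resumable_state = mp_iterator.get_state()
+                break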
+    """
+
+    def __init__(
+        self,
+        base_iterator: StatefulIterator,
+        *,
+        n_batches_to_prefetch: int,
+        prefetch_buffer: list | None = None
+    ):
+        self.base_iterator = base_iterator
+        self.n_batches_to_prefetch = n_batches_to_prefetch
+        if prefetch_buffer is None:
+            prefetch_buffer = []
+        self.prefetch_buffer = prefetch_buffer
+        self.batch_queue = None
+        self.state_queue = None
+        self.producer = None
+        self.stop_iterating_event = None
+        self.state_dumped_event = None
+
+    def get_state(self) -> MultiprocessIteratorState:
+        """
+        This is slightly unusual in that it effectively destroys the current iterator: it's necessary
+        to halt the background process and allow it to write its state back to the main loop
+        so that no data is lost.
+        """
+        if self.producer is None:
+            serialized_prefetch_buffer = json.dumps(
+                [b.to_python_dict() for b in self.prefetch_buffer]
+            )
+            return MultiprocessIteratorState(
+                base_iterator_state=self.base_iterator.get_state(),
+                n_batches_to_prefetch=self.n_batches_to_prefetch,
+                serialized_prefetch_buffer=serialized_prefetch_buffer,
+            )
+        else:
+            logging.info("Main thread: Sending stop iteration event")
+            self.stop_iterating_event.set()
+            logging.info("Main thread: Waiting for state_dumped event")
+            self.state_dumped_event.wait()
+            self.prefetch_buffer = []
+            final_batch_received = False
+            while True:
+                try:
+                    batch = self.batch_queue.get(timeout=1)
+                    if batch.is_final:
+                        final_batch_received = True
+                        break
+                    self.prefetch_buffer.append(batch)
+                except Empty:
+                    logging.warning("Main thread: batch_queue is abnormally empty")
+            assert final_batch_received
+
+            try:
+                base_iterator_state = self.state_queue.get(timeout=1)
+                assert isinstance(base_iterator_state, IteratorState)
+            except Empty:
+                raise ValueError(
+                    "Attempted to get the state, but it was unexpectantly missing"
+                )
+
+            self.base_iterator = base_iterator_state.build()
+            self.producer.close()
+            self.producer = None
+            self.batch_queue = None
+            self.state_queue = None
+            self.stop_iterating_event = None
+            self.state_dumped_event = None
+
+            return MultiprocessIteratorState(
+                base_iterator_state=self.base_iterator.get_state(),
+                n_batches_to_prefetch=self.n_batches_to_prefetch,
+                serialized_prefetch_buffer=json.dumps(
+                    [b.to_python_dict() for b in self.prefetch_buffer]
+                ),
+            )
+
+    def create_iter(self):
+        logging.info("Main thread: Creating MP iterator")
+        # First yield from the stored prefetch buffer.
+        if self.prefetch_buffer is not None:
+            while len(self.prefetch_buffer) > 0:
+                item = self.prefetch_buffer.pop(0)
+                yield item
+            self.prefetch_buffer = None
+
+        assert (
+            self.producer is None
+        ), "Cannot create two parallel iterators at once, call get_state() then remake to have two."
+
+        # using an explicit mp context (forkserver) avoids excessive CPU load
+        ctx = mp.get_context("forkserver")
+        self.batch_queue = ctx.Manager().Queue(maxsize=self.n_batches_to_prefetch)
+
+        # There should only ever be one state, which is output when the stop event is detected
+        self.state_queue = ctx.Manager().Queue(maxsize=1)
+
+        self.stop_iterating_event = ctx.Event()
+        self.state_dumped_event = ctx.Event()
+
+        self.producer = mp.Process(
+            name="blt_data_loader",
+            target=start_work_from_state,
+            args=(
+                self.batch_queue,
+                self.state_queue,
+                self.stop_iterating_event,
+                self.state_dumped_event,
+                self.base_iterator.get_state(),
+            ),
+        )
+        logger.info("Async dataloader started")
+        self.producer.start()
+
+        while True:
+            if self.producer.exitcode is not None:
+                raise RuntimeError(
+                    "Data loader quit unexpectedly, real error has been raised previously"
+                )
+            try:
+                batch = self.batch_queue.get(timeout=0.1)
+                assert isinstance(batch, Batch)
+                assert (
+                    not batch.is_final
+                ), "is_final should only be used during get_state() being called"
+                yield batch
+            except Empty:
+                pass
+            if self.producer is None:
+                raise ValueError(
+                    "Attempted to call this iterator after calling get_state(). You must call create_iter() to make a new iterator instead."
+                )
diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
new file mode 100644
index 0000000..361fc03
--- /dev/null
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -0,0 +1,226 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+from typing import Any
+
+import numpy as np
+from pydantic import BaseModel, ConfigDict
+
+from bytelatent.data.data_types import Batch, BltSequence
+from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
+
+
+class PackingArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    batch_size: int
+    seq_len: int
+    pad_id: int
+    max_length: int | None
+    pad_to_max_length: bool
+    enable_byte_ngrams: bool
+
+
+class PackingIteratorState(BaseModel, IteratorState):
+    model_config = ConfigDict(extra="forbid")
+    sequence_iterator_state: SamplingIteratorState
+    packing_args: PackingArgs
+
+    def build(self) -> "PackingIterator":
+        return PackingIterator(
+            sequence_iterator=self.sequence_iterator_state.build(),
+            packing_args=self.packing_args,
+        )
+
+
+def _merge_patch_seq_masks(bs, slen: int, mask_seqs: list[list[bool]]):
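+    """
+    Merge per-sequence masks into a (bs, slen) boolean array aligned with the shifted
+    targets (row i receives mask_seqs[i][1:]). Returns None when every mask is all-True
+    and all masks share the same length, so no mask needs to be stored.
+    """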
+    assert len(mask_seqs) == bs
+    lens = [len(m) for m in mask_seqs]
+    if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens):
+        return None
+    assert slen == max(lens) - 1
+    mask = np.zeros((bs, slen), dtype=bool)
+    for i, m in enumerate(mask_seqs):
+        if m is None:
+            print(
+                "None masks are not implemented; the mask should be True for all tokens and must be passed explicitly to this function."
+            )
+            raise NotImplementedError
+        mask[i][: len(mask_seqs[i]) - 1] = mask_seqs[i][1:]
+    return mask
+
+
+def truncate_batch(
+    batch: Batch,
+    max_length: int,
+    pad_id: int,
+    pad_to_max_length: bool = False,
+    *,
+    enable_byte_ngrams: bool,
+):
+    """
+    Truncate x to a given size, making sure we remove the corresponding patch sizes in patch_lengths
+    and fix batch.mask.
+
+    batch.patch_lengths keeps its shape;
+    x, y, and mask may shrink in size.
+    """
+    if batch.patch_lengths is None:
+        return batch
+
+    seq_lengths = batch.patch_lengths.sum(axis=1)
+    max_length_adj = max_length + 1
+    if np.any(seq_lengths > max_length_adj):
+        for i in range(batch.x.shape[0]):
+            if seq_lengths[i] > max_length_adj:
+                # Find id of patch that tips over max_length + 1
+                count, j = 0, 0
+                while count + batch.patch_lengths[i, j] <= max_length_adj:
+                    count += batch.patch_lengths[i, j]
+                    j += 1
+                # Edit the batch
+                assert j < batch.patch_lengths.shape[1]
+                batch.x[i, max_length:] = pad_id
+                batch.y[i, max_length:] = pad_id
+                if batch.mask is not None:
+                    batch.mask[i, max_length:] = False
+                batch.patch_lengths[i, j:] = 0
+                batch.patch_lengths[i, j] = max_length_adj - count
+
+        # Truncate if necessary.
+        if max_length < batch.x.shape[1]:
+            batch.x = batch.x[:, :max_length]
+            batch.y = batch.y[:, :max_length]
+            if batch.mask is not None:
+                batch.mask = batch.mask[:, :max_length]
+
+    # Right pad to max_length if necessary
+    elif pad_to_max_length:
+        if batch.x.shape[1] < max_length:
+            # NOTE: this has to be done on an actual patch.
+            non_zero_indices = (batch.patch_lengths != 0).sum(axis=1) - 1
+            non_zero_indices = np.maximum(0, non_zero_indices)
+            batch.patch_lengths[range(len(batch.patch_lengths)), non_zero_indices] += (
+                max_length - batch.x.shape[1]
+            )
+            # TODO: We could get rid of many of these complications by moving this function directly into the dataloader.
+            x = np.full((batch.x.shape[0], max_length), pad_id, dtype=batch.x.dtype)
+            x[:, : batch.x.shape[1]] = batch.x
+            batch.x = x
+        if batch.y.shape[1] < max_length:
+            y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype)
+            y[:, : batch.y.shape[1]] = batch.y
+            batch.y = y
+        if batch.mask is not None and batch.mask.shape[1] < max_length:
+            mask = np.full(
+                (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
+            )
+            mask[:, : batch.mask.shape[1]] = batch.mask
+            batch.mask = mask
+
+    assert batch.x.shape[1] <= max_length
+    assert batch.y.shape[1] <= max_length
+    assert batch.mask is None or batch.mask.shape[1] <= max_length
+    assert np.all(max_length_adj - batch.patch_lengths.sum(axis=1) == 0)
+    if pad_to_max_length:
+        assert batch.x.shape[1] == max_length
+        assert batch.y.shape[1] == max_length
+        assert batch.mask is None or batch.mask.shape[1] == max_length
+    if enable_byte_ngrams:
+        raise NotImplementedError()
+        # When implemented, ngram_ids should have shape (num_ngram, batch_size, seq_len), e.g.:
+        # ngram_ids = np.array(tokenizer.encode_token_ngrams(batch.x))
+        # assert ngram_ids.shape[2] == batch.x.shape[1]
+    else:
+        ngram_ids = None
+    batch.ngram_ids = ngram_ids
+
+
+class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
+    def __init__(
+        self,
+        sequence_iterator: StatefulIterator[BltSequence, Any],
+        *,
+        packing_args: PackingArgs,
+    ):
+        self.sequence_iterator = sequence_iterator
+        self.packing_args = packing_args
+
+    def get_state(self):
+        return PackingIteratorState(
+            sequence_iterator_state=self.sequence_iterator.get_state(),
+            packing_args=self.packing_args,
+        )
+
+    def create_iter(self):
+        sequence_iter = self.sequence_iterator.create_iter()
+        batch_size = self.packing_args.batch_size
+        pad_id = self.packing_args.pad_id
+        seq_len = self.packing_args.seq_len
+        pad_to_max_length = self.packing_args.pad_to_max_length
+        enable_byte_ngrams = self.packing_args.enable_byte_ngrams
+        max_length = self.packing_args.max_length
+        while True:
+            tokens: list[list[int]] = []
+            masks: list[list[bool]] = []
+            patch_lengths: list[list[int]] = []
+
+            for _ in range(self.packing_args.batch_size):
+                sequence = next(sequence_iter)
+                _tokens = sequence.tokens
+                _mask = sequence.mask
+                _patch_lengths = sequence.patch_lengths
+                assert len(sequence.patch_lengths) == self.packing_args.seq_len
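+                # If the first patch is longer than 1 token, split a length-1 patch off the
+                # front and drop the last patch so the number of patches stays at seq_len;
+                # the dropped patch's tokens are trimmed from tokens/mask below.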
+                last_patch_length = 0
+                if _patch_lengths[0] > 1:
+                    last_patch_length = _patch_lengths[-1]
+                    _patch_lengths[0] -= 1
+                    _patch_lengths = [1] + _patch_lengths[:-1]
+                tokens.append(_tokens[: len(_tokens) - last_patch_length])
+                masks.append(_mask[: len(_mask) - last_patch_length])
+                patch_lengths.append(_patch_lengths)
+
+            x_patch_lengths = np.array(patch_lengths)
+            # pad batch to same length
+            tok_seq_len = max([len(toks) for toks in tokens]) - 1
+            x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
+            y = np.full((batch_size, tok_seq_len), fill_value=pad_id)
+
+            for i, tok_seq in enumerate(tokens):
+                x[i, : len(tok_seq) - 1] = tok_seq[:-1]
+                y[i, : len(tok_seq) - 1] = tok_seq[1:]
+                # Adjust patch lengths to match x
+                x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
+
+            assert x_patch_lengths.shape == (batch_size, seq_len)
+
+            if enable_byte_ngrams:
+                raise NotImplementedError()
+            else:
+                ngram_ids = None
+
+            batch = Batch(
+                x=x,
+                y=y,
+                patch_lengths=x_patch_lengths,
+                ngram_ids=ngram_ids,
+                mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
+            )
+            assert (
+                x_patch_lengths.sum() == x.size + batch_size
+            ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
+            assert (
+                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
+            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
+            assert np.all(
+                x_patch_lengths[:, 0] == 1
+            ), f"first patch should always be 1, {x_patch_lengths[:, 0]}"
+            # cuda_gb_allocated = (torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024)
+            # cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024
+            # print(f"dataloader cuda_gb_allocated: {cuda_gb_allocated}, cuda_gb_reserved: {cuda_gb_reserved}")
+            truncate_batch(
+                batch,
+                max_length=max_length,
+                pad_id=pad_id,
+                pad_to_max_length=pad_to_max_length,
+                enable_byte_ngrams=enable_byte_ngrams,
+            )
+            yield batch
diff --git a/bytelatent/data/iterators/preprocess_iterator.py b/bytelatent/data/iterators/preprocess_iterator.py
new file mode 100644
index 0000000..8eeba41
--- /dev/null
+++ b/bytelatent/data/iterators/preprocess_iterator.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+from typing import Any, Generator
+
+import torch
+from pydantic import BaseModel, ConfigDict
+
+from bytelatent.data.data_types import BltExample
+from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.arrow_iterator import (
+    ArrowFileIterator,
+    ArrowFileIteratorState,
+)
+from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
+from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+
+
+class PreprocessIteratorState(BaseModel, IteratorState):
+    model_config = ConfigDict(extra="forbid")
+    arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
+    add_tokens: bool
+    add_patches: bool
+    tokenizer_args: TokenizerArgs
+    patcher_args: PatcherArgs
+
+    def build(self):
+        arrow_iterator = self.arrow_file_iterator_state.build()
+        return PreprocessIterator(
+            arrow_iterator,
+            patcher_args=self.patcher_args,
+            tokenizer_args=self.tokenizer_args,
+            add_tokens=self.add_tokens,
+            add_patches=self.add_patches,
+        )
+
+
+class PreprocessIterator(StatefulIterator):
+    """
+    Take BltExamples whose fields are only filled in by the ArrowFileIterator, and fill in the fields
+    that require preprocessing, such as tokenization and patching.
+    """
+
+    def __init__(
+        self,
+        arrow_iterator: ArrowFileIterator,
+        *,
+        patcher_args: PatcherArgs,
+        tokenizer_args: TokenizerArgs,
+        add_tokens: bool = True,
+        add_patches: bool = True,
+    ):
+        self.arrow_iterator = arrow_iterator
+        self.tokenizer_args = tokenizer_args
+        self.patcher_args = patcher_args
+        self.add_tokens = add_tokens
+        self.add_patches = add_patches
+        self.tokenizer: BltTokenizer | None = None
+        self.patcher: Patcher | None = None
+
+    def get_state(self) -> PreprocessIteratorState:
+        """
+        The only state to maintain here comes from the arrow iterator;
+        this iterator has no internal state of its own.
+        """
+        return PreprocessIteratorState(
+            arrow_file_iterator_state=self.arrow_iterator.get_state(),
+            tokenizer_args=self.tokenizer_args,
+            patcher_args=self.patcher_args,
+            add_tokens=self.add_tokens,
+            add_patches=self.add_patches,
+        )
+
+    def create_iter(self) -> Generator[BltExample, Any, None]:
+        if self.tokenizer is None and self.add_tokens:
+            self.tokenizer = self.tokenizer_args.build()
+        if self.patcher is None and self.add_patches:
+            self.patcher = self.patcher_args.build()
+
+        example_iter = self.arrow_iterator.create_iter()
+        for example in example_iter:
+            if self.add_tokens:
+                tokens = self.tokenizer.encode(example.text)
+            else:
+                tokens = example.tokens
+            if (
+                self.patcher is not None
+                and self.patcher.patching_mode == PatchingModeEnum.entropy
+            ):
+                assert (
+                    example.entropies is not None
+                ), "For patching, entropies cannot be None"
+                entropies = torch.tensor(example.entropies).unsqueeze(0)
+            else:
+                entropies = None
+            if self.patcher is None:
+                patch_lengths = None
+            else:
+                patch_lengths = self.patcher.patch(
+                    torch.tensor(tokens).unsqueeze(0),
+                    include_next_token=False,
+                    entropies=entropies,
+                )[0][0].tolist()
+            yield BltExample(
+                sample_id=example.sample_id,
+                text=example.text,
+                tokens=tokens,
+                mask=[True] * len(tokens),
+                patch_lengths=patch_lengths,
+                entropies=example.entropies,
+            )
diff --git a/bytelatent/data/iterators/sampling_iterator.py b/bytelatent/data/iterators/sampling_iterator.py
new file mode 100644
index 0000000..6474bf6
--- /dev/null
+++ b/bytelatent/data/iterators/sampling_iterator.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+from typing import Any
+
+import numpy as np
+from pydantic import BaseModel, ConfigDict
+
+from bytelatent.data.iterators.abstract_iterator import StatefulIterator
+from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState
+
+
+class SamplingIteratorState(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    rng_state: dict[str, Any]
+    source_to_weight: dict[str, float]
+    source_to_iterator_state: dict[str, SequenceIteratorState]
+
+    def build(self) -> "SamplingIterator":
+        return SamplingIterator(
+            rng_state=self.rng_state,
+            source_to_weight=self.source_to_weight,
+            source_to_iterator={
+                source: state.build()
+                for source, state in self.source_to_iterator_state.items()
+            },
+        )
+
+
+class SamplingIterator(StatefulIterator):
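+    """
+    Yields sequences by repeatedly picking a source iterator at random, with probability
+    proportional to its weight (weights are normalized on every draw), and taking that
+    source's next item.
+    """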
+    def __init__(
+        self,
+        *,
+        rng_state: dict[str, Any],
+        source_to_weight: dict[str, float],
+        source_to_iterator: dict[str, StatefulIterator],
+    ):
+        self.rng = np.random.default_rng()
+        self.rng.bit_generator.state = rng_state
+        self.source_to_weight = source_to_weight
+        self.source_to_iterator = source_to_iterator
+
+    def get_state(self) -> SamplingIteratorState:
+        return SamplingIteratorState(
+            rng_state=self.rng.bit_generator.state,
+            source_to_weight=self.source_to_weight,
+            source_to_iterator_state={
+                source: iterator.get_state()
+                for source, iterator in self.source_to_iterator.items()
+            },
+        )
+
+    def create_iter(self):
+        n_sources = len(self.source_to_weight)
+        possible_sources = []
+        weights = []
+        for source, w in self.source_to_weight.items():
+            possible_sources.append(source)
+            weights.append(w)
+
+        source_to_python_iter = {
+            source: self.source_to_iterator[source].create_iter()
+            for source in possible_sources
+        }
+        while True:
+            norm_weights = np.array(weights) / np.array(weights).sum()
+            source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)]
+            yield next(source_to_python_iter[source_choice])
diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py
new file mode 100644
index 0000000..14e3747
--- /dev/null
+++ b/bytelatent/data/iterators/sequence_iterator.py
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+from logging import getLogger
+from typing import Any
+
+import numpy as np
+from pydantic import BaseModel, ConfigDict
+
+from bytelatent.data.data_types import BltSequence
+from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.preprocess_iterator import (
+    PreprocessIterator,
+    PreprocessIteratorState,
+)
+
+logger = getLogger()
+
+
+class SequencePackingArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    output_seq_len: int
+    buffer_size: int
+
+
+class SequenceIteratorState(BaseModel, IteratorState):
+    model_config = ConfigDict(extra="forbid")
+    sequence_packing_args: SequencePackingArgs
+    preprocess_iterator_state: PreprocessIteratorState
+    rng_state: dict[str, Any]
+
+    def build(self):
+        preprocess_iterator = self.preprocess_iterator_state.build()
+        return SequenceIterator(
+            preprocess_iterator,
+            sequence_packing_args=self.sequence_packing_args,
+            rng_state=self.rng_state,
+        )
+
+
+class SequenceIterator(StatefulIterator):
+    def __init__(
+        self,
+        preprocess_iterator: PreprocessIterator,
+        *,
+        rng_state: dict[str, Any],
+        sequence_packing_args: SequencePackingArgs,
+    ):
+        self.preprocess_iterator = preprocess_iterator
+        self.sequence_packing_args = sequence_packing_args
+        self.output_seq_len = sequence_packing_args.output_seq_len
+        self.buffer_size = sequence_packing_args.buffer_size
+        self.rng = np.random.default_rng()
+        self.rng.bit_generator.state = rng_state
+
+    def get_state(self):
+        # TODO: need to also persist the current shuffle buffer
+        return SequenceIteratorState(
+            sequence_packing_args=self.sequence_packing_args,
+            preprocess_iterator_state=self.preprocess_iterator.get_state(),
+            rng_state=self.rng.bit_generator.state,
+        )
+
+    def create_iter(self):
+        example_iter = self.preprocess_iterator.create_iter()
+        n_buffer_patches = self.buffer_size * self.output_seq_len
+
+        patch_lengths: list[int] = []
+        tokens: list[int] = []
+        mask: list[bool] = []
+        first = True
+        for example in example_iter:
+            assert example.tokens is not None
+            assert example.mask is not None
+            assert example.patch_lengths is not None
+            assert len(example.tokens) != 0
+            assert len(example.mask) != 0
+            assert len(example.tokens) == len(example.mask)
+            assert len(example.tokens) == sum(example.patch_lengths)
+
+            tokens.extend(example.tokens)
+            mask.extend(example.mask)
+            patch_lengths.extend(example.patch_lengths)
+
+            while len(patch_lengths) >= n_buffer_patches:
+                if first:
+                    first = False
+                    logger.info("First buffer complete")
+
+                x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
+                    self.buffer_size, self.output_seq_len
+                )
+                seq_tokens = []
+                seq_mask = []
+                start_id = 0
+                # We fix the number of patches (and therefore global steps) per batch,
+                # so the number of tokens varies and has to be accounted for
+                for num_tokens in x_patches.sum(axis=-1):
+                    seq_tokens.append(tokens[start_id : start_id + num_tokens])
+                    seq_mask.append(mask[start_id : start_id + num_tokens])
+                    start_id += num_tokens
+
+                assert start_id == x_patches.sum()
+
+                # Remove what we just added from the buffer
+                patch_lengths = patch_lengths[n_buffer_patches:]
+                tokens = tokens[x_patches.sum() :]
+                mask = mask[x_patches.sum() :]
+
+                seq_patch_lengths: list[list[int]] = x_patches.tolist()
+                assert len(seq_patch_lengths) == self.buffer_size
+                for idx in self.rng.permutation(len(seq_patch_lengths)):
+                    assert len(seq_patch_lengths[idx]) == self.output_seq_len
+                    assert (
+                        sum(seq_patch_lengths[idx])
+                        == len(seq_tokens[idx])
+                        == len(seq_mask[idx])
+                    ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
+                    assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
+                    yield BltSequence(
+                        tokens=seq_tokens[idx],
+                        mask=seq_mask[idx],
+                        patch_lengths=seq_patch_lengths[idx],
+                    )
diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py
new file mode 100644
index 0000000..4266427
--- /dev/null
+++ b/bytelatent/data/iterators/test_arrow_iterator.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import numpy as np
+import pyarrow as pa
+
+# pyarrow needs the initialization from this import
+import pyarrow.dataset  # pyright: ignore
+
+from bytelatent.constants import BLT_DATA
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
+
+ENTROPY_MODEL = "transformer_100m"
+ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
+ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow")
+
+
+def test_basic_arrow_file():
+    dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow")
+    n_head = 1000
+    head_df = dataset.head(n_head).to_pandas()
+
+    initial_state = ArrowFileIteratorState(
+        file_path=None,
+        num_workers=1,
+        worker_id=0,
+        preprocess_dir=None,
+        entropy_model_name=ENTROPY_MODEL,
+        dataset_files=[ARROW_TEST_DATA_1],
+        row_num=0,
+        arrow_batch_size=100,
+    )
+    arrow_file = initial_state.build()
+    start_state = arrow_file.get_state()
+    assert start_state.row_num == initial_state.row_num
+
+    sample_id = None
+    for example in arrow_file.create_iter():
+        sample_id = example.sample_id
+        assert head_df.iloc[0]["sample_id"] == sample_id
+        break
+
+    assert arrow_file.get_state().row_num == 1
+    arrow_file = initial_state.build()
+    for example in arrow_file.create_iter():
+        assert example.sample_id == sample_id
+        assert head_df.iloc[0]["sample_id"] == sample_id
+        break
+
+    # Test resume far enough in to be past the batch size of 100
+    resumed_state = ArrowFileIteratorState(
+        file_path=None,
+        num_workers=1,
+        worker_id=0,
+        preprocess_dir=None,
+        entropy_model_name=ENTROPY_MODEL,
+        dataset_files=[ARROW_TEST_DATA_1],
+        row_num=251,
+        arrow_batch_size=100,
+    )
+    arrow_file = resumed_state.build()
+    for example in arrow_file.create_iter():
+        assert example.sample_id == head_df.iloc[251]["sample_id"]
+        assert arrow_file.get_state().row_num == 252
+        break
+
+    world_rank = 1
+    world_size = 4
+    # Test World Size and Rank
+    rank_state = ArrowFileIteratorState(
+        file_path=None,
+        num_workers=world_size,
+        worker_id=world_rank,
+        preprocess_dir=None,
+        entropy_model_name=ENTROPY_MODEL,
+        dataset_files=[ARROW_TEST_DATA_1],
+        row_num=0,
+        arrow_batch_size=100,
+    )
+    arrow_file = rank_state.build()
+    expected_ids = []
+    for i in range(n_head):
+        if i % world_size == world_rank:
+            expected_ids.append(head_df.iloc[i]["sample_id"])
+    print(len(expected_ids))
+    i = 0
+    for example in arrow_file.create_iter():
+        assert example.sample_id == expected_ids[i]
+        i += 1
+        if i >= len(expected_ids):
+            break
diff --git a/bytelatent/data/iterators/test_iters.py b/bytelatent/data/iterators/test_iters.py
new file mode 100644
index 0000000..9bc9d59
--- /dev/null
+++ b/bytelatent/data/iterators/test_iters.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import pandas as pd
+from pydantic import BaseModel
+
+from bytelatent.constants import BLT_DATA
+from bytelatent.data.data_types import BltExample
+from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
+from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+
+
+class BltTestIteratorState(BaseModel, IteratorState):
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestIterator(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=f"This is some test {i} text.",
+                tokens=None,
+                mask=None,
+                entropies=None,
+                patch_lengths=None,
+            )
+
+
+class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestWithEntropiesIterator(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestWithEntropiesIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
+        df = pd.read_json("fixtures/tokens_with_entropies.json")
+        tokens = df["token_ids"].tolist()
+        entropies = df["entropies"].tolist()
+        # BOS and EOS
+        assert len(tokens) == len(text) + 2
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=text,
+                tokens=tokens,
+                mask=[True] * len(tokens),
+                entropies=entropies,
+                patch_lengths=None,
+            )
+
+
+def test_preprocess_iter():
+    total = 3
+    tokenizer_args = TokenizerArgs(
+        name="blt",
+        init_kwargs={
+            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
+        },
+    )
+    for mode in [
+        PatchingModeEnum.bpe,
+        PatchingModeEnum.space,
+    ]:
+        data_it = BltTestIterator(total)
+        patcher_args = PatcherArgs(patching_mode=mode)
+        example_it = PreprocessIterator(
+            data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
+        )
+        count = 0
+        for example in example_it.create_iter():
+            assert isinstance(example.tokens, list)
+            assert isinstance(example.tokens[0], int)
+            # BOS and EOS
+            assert len(example.tokens) == len(example.text) + 2
+            assert example.mask is not None
+            assert len(example.tokens) == len(example.mask)
+            count += 1
+
+        assert count == total
+
+
+def test_non_entropy_patch_iter():
+    total = 3
+    tokenizer_args = TokenizerArgs(
+        name="blt",
+        init_kwargs={
+            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
+        },
+    )
+    for mode in [
+        PatchingModeEnum.bpe,
+        PatchingModeEnum.space,
+    ]:
+        patcher_args = PatcherArgs(patching_mode=mode)
+        data_it = BltTestIterator(total)
+        example_it = PreprocessIterator(
+            data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
+        )
+
+        count = 0
+        for example in example_it.create_iter():
+            assert isinstance(example.patch_lengths, list)
+            assert isinstance(example.patch_lengths[0], int)
+            assert len(example.tokens) == sum(example.patch_lengths)
+            count += 1
+
+        assert count == total
+
+
+def test_entropy_patch_iter():
+    total = 2
+    patcher_args = PatcherArgs(
+        patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627
+    )
+    tokenizer_args = TokenizerArgs(
+        name="blt",
+        init_kwargs={
+            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
+        },
+    )
+    data_it = BltTestWithEntropiesIterator(total)
+    example_it = PreprocessIterator(
+        data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
+    )
+
+    count = 0
+    for example in example_it.create_iter():
+        assert isinstance(example.patch_lengths, list)
+        assert isinstance(example.patch_lengths[0], int)
+        assert len(example.tokens) == sum(example.patch_lengths)
+        count += 1
+
+    assert count == total
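
The test iterators above also give a compact picture of the `StatefulIterator`/`IteratorState` contract. A minimal sketch (not part of this patch) of the checkpoint-and-rebuild round trip, using only the classes defined in this file:

```python
# Round-trip sketch: checkpoint the iterator's state, then rebuild from it.
it = BltTestIterator(total=4)
gen = it.create_iter()
next(gen)                # consume one example; it.position is now 1
state = it.get_state()   # BltTestIteratorState(position=1, total=4)
rebuilt = state.build()  # a fresh iterator carrying the saved position
assert rebuilt.position == 1 and rebuilt.total == 4
```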
diff --git a/bytelatent/data/ngram_processor.py b/bytelatent/data/ngram_processor.py
new file mode 100644
index 0000000..2498183
--- /dev/null
+++ b/bytelatent/data/ngram_processor.py
@@ -0,0 +1,146 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import pickle
+from pathlib import Path
+
+import numpy as np
+
+from bytelatent import ByteLatentError
+
+LOOKUP_OFFSET = 4
+
+
+def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1):
+    """
+    Wrapper that builds a lookup function over a fixed n-gram-to-index table.
+
+    :param ngram_to_idx: Dictionary where keys are n-gram tuples and values are their table indices.
+    :param lookup_offset: Offset to add to the lookup result.
+    :return: A function mapping an n-gram array to its offset index, or 0 if the n-gram is unknown.
+    """
+
+    def apply_lookup_table(ngram):
+        """
+        Function to apply to each n-gram: converts it to a tuple and looks it up in a dictionary.
+
+        :param ngram: Array of numbers representing an n-gram.
+        :return: The index of the n-gram plus lookup_offset, or 0 if the n-gram is not in the table.
+        """
+        # Convert the n-gram to a tuple
+        ngram_tuple = tuple(ngram)
+
+        if ngram_tuple not in ngram_to_idx:
+            return 0
+        else:
+            return ngram_to_idx[ngram_tuple] + lookup_offset
+
+    return apply_lookup_table
+
+
+def get_byte_ngrams_ids(
+    byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0
+):
+    """
+    Generate n-gram IDs from a 2D numpy array of byte values.
+
+    :param byte_array: 2D array of byte values, shape [batch, seq_len].
+    :param n: The length of each n-gram.
+    :param ngram_to_idx: Mapping from n-gram tuples to indices in the lookup table.
+    :param pad_value: Value used to left-pad the bytes so the n-gram grid keeps the input's dimensions.
+    :return: A 2D numpy array where each element is the ID of an n-gram, offset by LOOKUP_OFFSET.
+    """
+    num_rows, num_cols = byte_array.shape
+
+    # Create an array to hold the padded version of the original array
+    padded_array = np.pad(
+        byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value
+    )
+
+    # Use stride tricks to avoid explicit looping
+    strided = np.lib.stride_tricks.as_strided
+    shape = (num_rows, num_cols, n)
+    strides = padded_array.strides[:2] + (padded_array.strides[1],)
+    ngrams = strided(padded_array, shape=shape, strides=strides)
+
+    ngram_ids = np.apply_along_axis(
+        apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET), 2, ngrams
+    )
+    assert ngram_ids.shape == byte_array.shape
+    return ngram_ids
+
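For concreteness, a small worked example (hypothetical table and input, not from this patch) showing how the left padding and LOOKUP_OFFSET interact:

```python
import numpy as np

# Hypothetical 3-gram table; real tables come from reload_tables().
ngram_to_idx = {(0, 104, 105): 0, (104, 105, 33): 1}
byte_array = np.array([[104, 105, 33]])  # bytes of "hi!"
ids = get_byte_ngrams_ids(byte_array, n=3, ngram_to_idx=ngram_to_idx)
# Position j looks up the 3-gram ending at j, left-padded with zeros:
# (0, 0, 104) -> unknown -> 0; (0, 104, 105) -> 0 + 4; (104, 105, 33) -> 1 + 4
assert ids.tolist() == [[0, 4, 5]]
```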
+
+def reload_tables(
+    ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET
+) -> tuple[dict[int, dict[tuple, int]], dict[int, list], dict[int, int]]:
+    """
+    Reload lookup tables from a directory. Reload only the ngrams in the dictionary and per ngram,
+    only load up to the max specified size. Return the actual number of ngrams taken per ngram size.
+    """
+    idx_to_ngram_tables = {}
+    ngram_to_idx_tables = {}
+    vocab_sizes = {}
+    for ngram, size in ngram_to_size.items():
+        with open(Path(ngram_table_dir) / f"ngram-{ngram}.pickle", "rb") as f:
+            # These are already sorted by count
+            # Value: tuple of: count, ngram, dataset
+            ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[
+                "counts"
+            ]
+            table = [ngram for ngram, _ in ngram_data][:size]
+            if len(table) != size:
+                raise ValueError(
+                    f"Ngram table for {ngram}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}"
+                )
+            ngram_to_idx = {ngram: idx for idx, ngram in enumerate(table)}
+            actual_size = len(table)
+            idx_to_ngram_tables[ngram] = table
+            ngram_to_idx_tables[ngram] = ngram_to_idx
+            vocab_sizes[ngram] = actual_size + offset
+    return ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes
+
+
+def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None:
+    if ngram_to_size_str is None:
+        return None
+    ngram_to_size = {}
+    for entry in ngram_to_size_str.split(","):
+        ngram, size = entry.split(":")
+        ngram = int(ngram)
+        size = int(size)
+        ngram_to_size[ngram] = size
+    return ngram_to_size
+
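The expected string format is comma-separated `n:size` pairs, for example:

```python
assert parse_ngram_to_size("2:10000,3:20000") == {2: 10000, 3: 20000}
```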
+
+class NgramProcessor:
+    def __init__(
+        self,
+        ngram_table_dir: str | None = None,
+        ngram_to_size: dict[int, int] | None = None,
+    ):
+        if ngram_table_dir is None or ngram_to_size is None:
+            raise ByteLatentError(
+                "ngram_table_dir and ngram_to_size cannot be none if enable_byte_ngrams is True"
+            )
+        (
+            self.ngram_to_idx_tables,
+            self.idx_to_ngram_tables,
+            self.ngram_vocab_sizes,
+        ) = reload_tables(ngram_table_dir, ngram_to_size)
+        # Lowest to highest ngram
+        self.ngram_sizes = sorted(list(self.ngram_to_idx_tables.keys()))
+        # Although the model might not use all the ngrams, we need the tokenizer
+        # to produce ngram_ids such that index zero is the 2-gram, which the
+        # model's forward pass relies on later.
+        assert self.ngram_sizes[0] == 2
+
+    def encode_single_ngram_table(self, data: np.ndarray, n: int):
+        """
+        Return the n-grams of the input data for a given n
+        numpy array with ids of shape data.shape
+        """
+        return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0)
+
+    def encode_token_ngrams(self, data: np.ndarray):
+        """
+        Return the n-grams of the input data.
+        output shape: [ids with data.shape for n in self.ngram_sizes]
+        """
+        return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes]
diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
new file mode 100644
index 0000000..ede8b06
--- /dev/null
+++ b/bytelatent/data/patcher.py
@@ -0,0 +1,609 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import math
+import time
+from collections import defaultdict
+from enum import Enum
+
+import torch
+from pydantic import BaseModel
+from torch.nn import functional as F
+
+from bytelatent.distributed import get_local_rank
+from bytelatent.entropy_model import load_entropy_model
+
+from bytelatent.tokenizers.constants import BPE_ID, OFFSET
+
+
+class PatchingModeEnum(str, Enum):
+    entropy = "entropy"
+    bpe = "bpe"
+    bpe_patcher = "bpe_patcher"
+    space = "space"
+
+
+class PatcherArgs(BaseModel):
+    patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
+    patching_device: str = "cuda"
+    entropy_model_checkpoint_dir: str | None = None
+    realtime_patching: bool = False
+    threshold: float = 1.335442066192627
+    threshold_add: float | None = None
+    max_patch_length: int | None = None
+    patch_size: float = 4.5
+    patching_batch_size: int = 1
+    data_loader_patching: bool = False
+    device: str = "cuda"
+    monotonicity: bool = False
+    log_time: bool = False
+
+    def build(self) -> "Patcher":
+        return Patcher(self)
+
+
+def entropy(scores):
+    """
+    scores: [bs, seq_len, vocab]
+    returns [bs, seq_len]
+
+    Computes the entropy for each token in the batch.
+    Note: uses natural log.
+    """
+    log_probs = F.log_softmax(scores, dim=-1)
+    probs = torch.exp(log_probs)
+    p_log_p = log_probs * probs
+    entropy = -p_log_p.sum(dim=-1)
+    return entropy
+
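As a quick sanity check (a sketch, not part of this patch): uniform logits give the maximum entropy, ln(vocab).

```python
import torch

scores = torch.zeros(1, 1, 4)  # uniform distribution over a 4-token vocab
print(entropy(scores))         # tensor([[1.3863]]) == ln(4)
```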
+
+def calculate_entropies(
+    tokens: torch.Tensor, entropy_model, patching_batch_size, device: str | None = None
+):
+    """
+    tokens: 2D tensor of shape [batch_size, seq_len]
+    Return 2D tensor of shape [batch_size, seq_len] with entropies for each token.
+
+    Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
+    The entropy model can be executed on cpu or gpu; specify either 'cuda' or 'cpu' in the device argument.
+    """
+    with torch.no_grad():
+        entropies = []
+        max_length = getattr(entropy_model, "max_length", 8192)
+        batch_numel = max_length * patching_batch_size
+        splits = torch.split(tokens.flatten(), batch_numel)
+        for split in splits:
+            pad_size = (max_length - (split.numel() % max_length)) % max_length
+            pad = torch.zeros(
+                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
+            )
+            split = torch.cat((split, pad), dim=0)
+            split = split.reshape(-1, max_length)
+            if device is not None:
+                split = split.to(device)
+            assert torch.all(split >= 0) and torch.all(split < 260)
+            pred, _ = entropy_model(split)
+            pred = pred.reshape(-1, pred.shape[-1])[
+                : split.numel() - pad_size, :
+            ]  # [batch_size * seq_len, vocab]
+            pred_entropies = entropy(pred)
+            entropies.append(pred_entropies)
+
+        entropies = torch.cat(entropies, dim=0)
+        entropies = entropies.reshape(tokens.shape)
+    return entropies
+
+
+def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
+    """
+    entropies: [bs, seq_len] torch tensor of entropies
+    t: threshold
+    returns [bs, seq_len] mask where True indicates the start of a patch
+    """
+    bs, seq_len = entropies.shape
+    mask = torch.zeros_like(entropies, dtype=torch.bool)
+    mask[:, 0] = True
+
+    # Calculate differences between consecutive elements along the sequence length
+    differences = entropies[:, 1:] - entropies[:, :-1]
+
+    # Calculate conditions for all elements except the first one in each sequence
+    condition = differences > t
+
+    # Update the mask based on the condition
+    mask[:, 1:] = condition
+
+    return mask
+
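A small illustration (toy values): position 0 always starts a patch, and a new patch starts wherever the entropy jumps by more than `t` relative to the previous token.

```python
import torch

ent = torch.tensor([[0.1, 0.9, 0.2, 1.0]])
mask = patch_start_mask_from_entropy_with_monotonicity(ent, t=0.5)
assert mask.tolist() == [[True, True, False, True]]
```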
+
+def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
+    """
+    entropies: [bs, seq_len] torch tensor of entropies
+    t: global threshold on the entropy value
+    t_add: additional threshold on the entropy increase over the previous token
+    returns [bs, seq_len] mask where True indicates the start of a patch
+    """
+    bs, seq_len = entropies.shape
+    mask = torch.zeros_like(entropies, dtype=torch.bool)
+    mask[:, 0] = True
+
+    # Calculate differences between consecutive elements along the sequence length
+    differences = entropies[:, 1:] - entropies[:, :-1]
+
+    # Calculate conditions for all elements except the first one in each sequence
+    condition = (differences > t_add) & (entropies[:, 1:] > t) & (~mask[:, :-1])
+
+    # Update the mask based on the condition
+    mask[:, 1:] = condition
+
+    return mask
+
+
+def patch_start_ids_from_patch_start_mask(patch_start_mask):
+    bs, trunc_seq_len = patch_start_mask.shape
+    max_patches = patch_start_mask.sum(dim=1).max()
+    if max_patches == 0:
+        patch_start_ids = torch.full(
+            (bs, trunc_seq_len),
+            trunc_seq_len,
+            dtype=torch.long,
+            device=patch_start_mask.device,
+        )
+    else:
+        patch_ids = (
+            torch.arange(trunc_seq_len, device=patch_start_mask.device)
+            .unsqueeze(0)
+            .repeat(bs, 1)
+        )
+        extra_patch_ids = torch.full(
+            (bs, trunc_seq_len),
+            trunc_seq_len,
+            dtype=torch.long,
+            device=patch_start_mask.device,
+        )
+        all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
+        patch_start_mask_padded = torch.cat(
+            (patch_start_mask, ~patch_start_mask), dim=1
+        )
+        patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(
+            bs, trunc_seq_len
+        )[:, :max_patches]
+    return patch_start_ids
+
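For example (toy mask): starts at positions 0 and 2 of a length-4 sequence reduce to those indices; rows with fewer starts would be padded with the sequence length.

```python
import torch

mask = torch.tensor([[True, False, True, False]])
print(patch_start_ids_from_patch_start_mask(mask))  # tensor([[0, 2]])
```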
+
+def check_non_zero_after_zero(tensor):
+    zero_mask = tensor == 0
+    shifted_mask = torch.cat(
+        [
+            torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
+            zero_mask[:, :-1],
+        ],
+        dim=1,
+    )
+    non_zero_after_zero = (tensor != 0) & shifted_mask
+    return non_zero_after_zero.any()
+
+
+def patch_lengths_from_start_ids(patch_start_ids, seq_len):
+    """
+    Calculate patch lengths from start ids.
+    start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then
+        the rest are filled to the seq len.
+    seq_len: ex: 7 length of the sequence
+
+    returns the patch lengths:
+    [1, 6] for the above example, right-padded with zeros to the width of the start-id tensor.
+    """
+    last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
+    patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
+    patch_lengths = patch_end_ids - patch_start_ids + 1
+    assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
+    assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
+    return patch_lengths
+
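The docstring example, run end to end (a sketch):

```python
import torch

ids = torch.tensor([[0, 1, 7, 7, 7, 7, 7]])
print(patch_lengths_from_start_ids(ids, seq_len=7))
# tensor([[1, 6, 0, 0, 0, 0, 0]]) -- the [1, 6] patches, right-padded with zeros
```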
+
+def find_space_patch_start_ids(tokens):
+    bs, seq_len = tokens.shape
+    tokens_no_offset = tokens - OFFSET
+    patch_end_mask = (
+        (tokens_no_offset < ord("0"))
+        | ((ord("9") < tokens_no_offset) & (tokens_no_offset < ord("A")))
+        | ((ord("Z") < tokens_no_offset) & (tokens_no_offset < ord("a")))
+        | ((ord("z") < tokens_no_offset) & (tokens_no_offset < 0b1000_0000))
+        | (0b1100_0000 <= tokens_no_offset)
+    )
+    patch_end_mask[:, 1:] &= patch_end_mask[:, :-1].bitwise_not()
+    patch_end_mask |= tokens < OFFSET
+
+    patch_start_mask = torch.cat(
+        [
+            torch.tensor([1, 1], device=tokens.device, dtype=torch.bool)
+            .unsqueeze(0)
+            .repeat(bs, 1),
+            patch_end_mask[:, 1:],
+        ],
+        dim=1,
+    )
+    max_patches = patch_start_mask.sum(dim=1).max()
+
+    patch_ids = (
+        torch.arange(seq_len + 1, device=tokens.device).unsqueeze(0).repeat(bs, 1)
+    )
+    extra_patch_ids = torch.full(
+        (bs, seq_len + 1), seq_len + 1, dtype=torch.long, device=tokens.device
+    )
+    all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
+    patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
+
+    patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, -1)[
+        :, :max_patches
+    ]
+    return patch_start_ids
+
+
+def to_device(entropy_model, device=None):
+    if device == "cuda":
+        rank = get_local_rank()
+        device = f"cuda:{rank}"
+    entropy_model = entropy_model.to(device)
+    return entropy_model, device
+
+
+def model_pred_to_bpe_patching_pred(pred):
+    _, indices = torch.max(pred, dim=1)
+    return indices == BPE_ID
+
+
+def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None):
+    assert tokens.device == torch.device(
+        "cpu"
+    ), f"{tokens.device} != cpu expects tokens to be on cpu"
+    with torch.no_grad():
+        bpe_patcher_device, device = to_device(
+            bpe_patcher, device
+        )  # Get entropy model to right rank device.
+        bpe_patching_mask = []
+        max_length = getattr(bpe_patcher, "max_length", 8192)
+        batch_numel = max_length * patching_batch_size
+        splits = torch.split(tokens.flatten(), batch_numel)
+        for split in splits:
+            pad_size = (max_length - (split.numel() % max_length)) % max_length
+            pad = torch.zeros(
+                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
+            )
+            split = torch.cat((split, pad), dim=0)
+            split = split.reshape(-1, max_length).to(device)
+            assert torch.all(split >= 0) and torch.all(split < 260)
+            pred = bpe_patcher_device(split)
+            pred_cpu = pred[0].cpu()
+            pred_cpu = pred_cpu.reshape(-1, pred_cpu.shape[-1])[
+                : split.numel() - pad_size, :
+            ]  # [batch_size * seq_len, vocab]
+            bpe_patching_pred = model_pred_to_bpe_patching_pred(pred_cpu)
+            bpe_patching_mask.append(bpe_patching_pred)
+        bpe_patching_mask = torch.cat(bpe_patching_mask, dim=0)
+        bpe_patching_mask = bpe_patching_mask.reshape(tokens.shape)
+    return bpe_patching_mask
+
+
+def find_bpe_patcher_patch_start_ids(
+    tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True
+):
+    bs, seq_len = tokens.shape
+
+    first_ids = (
+        torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
+        .unsqueeze(0)
+        .repeat(bs, 1)
+    )
+    preds_truncation_len = first_ids.shape[1]
+    token_input = tokens[:, 1:] if include_next_token else tokens[:, 1:-1]
+    if token_input.shape[1] >= 1:
+        patch_start_mask = apply_bpe_patcher(
+            token_input, bpe_patcher, patching_batch_size, device
+        )
+        assert (
+            patch_start_mask.shape[1]
+            == tokens.shape[1] + include_next_token - preds_truncation_len
+        ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
+        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
+        patch_start_ids = torch.cat(
+            (first_ids, patch_start_ids + preds_truncation_len), dim=1
+        )
+    else:
+        patch_start_ids = first_ids
+    return patch_start_ids
+
+
+def find_entropy_patch_start_ids(
+    entropies,
+    patch_size=None,
+    threshold=None,
+    threshold_add=None,
+    monotonicity=False,
+    include_next_token=True,
+):
+    """
+    Use entropies to find the start ids of each patch.
+    Use patch_size or threshold to figure out the total number of patches to allocate.
+
+    When threshold is not None the number of patches is not constant between
+    different sequences, but patches can be identified incrementally rather than
+    decided globally using the entire sequence.
+    """
+    bs, seq_len = entropies.shape[:2]
+
+    first_ids = (
+        torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
+        .unsqueeze(0)
+        .repeat(bs, 1)
+    )
+    preds_truncation_len = first_ids.shape[
+        1
+    ]  # remove the first preds because they will be start of patches.
+    entropies = entropies[:, 1:]
+    if threshold is None:
+        num_patches = seq_len // patch_size
+        patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
+        patch_start_ids = patch_start_ids.sort(dim=1).values
+    else:
+        # Assumes that there is at least one token going over the threshold
+        if monotonicity:
+            patch_start_mask = patch_start_mask_from_entropy_with_monotonicity(
+                entropies, threshold
+            )
+        elif threshold_add is not None and threshold is not None:
+            patch_start_mask = patch_start_mask_global_and_monotonicity(
+                entropies, threshold, threshold_add
+            )
+        else:
+            patch_start_mask = entropies > threshold
+        if not include_next_token:
+            patch_start_mask = patch_start_mask[:, :-1]
+        # patch_start_mask[1:] |= tokens[:-1] < OFFSET
+        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
+
+    patch_start_ids = torch.cat(
+        (first_ids, patch_start_ids + preds_truncation_len), dim=1
+    )
+    return patch_start_ids
+
+
+def rightpad(seq, pad_id, max_len):
+    return seq + [pad_id] * (max_len - len(seq))
+
+
+def find_bpe_delim_patch_start_ids(tokens, delim):
+    ids = (tokens[:, :-1] == delim).nonzero(as_tuple=False)
+    out = [[0, 1] for _ in range(tokens.shape[0])]
+    for x, y in ids:
+        # start is at delim + 1, delim should be the last element in the patch.
+        out[x.item()].append(y.item() + 1)
+    max_len = max([len(elt) for elt in out])
+    out = [rightpad(elt, tokens.shape[1], max_len) for elt in out]
+    patch_start_ids = torch.tensor(out, dtype=tokens.dtype, device=tokens.device)
+    return patch_start_ids
+
+
+def find_lookup_table_start_mask(
+    tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
+):
+    window_size = lookup_table.ndim
+    # Unfold the tensor to get sliding windows
+    unfolded = tokens.unfold(1, window_size, 1)
+    # Gather indices for each dimension
+    indices = [unfolded[..., i] for i in range(window_size)]
+    # Access the lookup table using the gathered indices
+    result = lookup_table[indices]
+    return result
+
+
+def find_lookup_table_patch_start_ids(
+    tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
+):
+    bs, seq_len = tokens.shape
+
+    first_ids = (
+        torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
+        .unsqueeze(0)
+        .repeat(bs, 1)
+    )
+    preds_truncation_len = first_ids.shape[1]
+    window_size = lookup_table.ndim
+    assert window_size == 2, f"{window_size} != 2"
+    # The mask width is token_input width - window_size + 1; together with first_ids it
+    # should equal the token width (+1 when include_next_token is set).
+    token_input = (
+        tokens if include_next_token else tokens[:, : -preds_truncation_len + 1]
+    )
+    if token_input.shape[1] >= window_size:
+        patch_start_mask = find_lookup_table_start_mask(
+            token_input, lookup_table, include_next_token
+        )
+        assert (
+            patch_start_mask.shape[1]
+            == tokens.shape[1] + include_next_token - preds_truncation_len
+        ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
+        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
+        patch_start_ids = torch.cat(
+            (first_ids, patch_start_ids + preds_truncation_len), dim=1
+        )
+    else:
+        patch_start_ids = first_ids
+    return patch_start_ids
+
+
+def split_large_numbers(lst, m):
+    new_lst = []
+    for i in lst:
+        if i > m:
+            while i > m:
+                new_lst.append(m)
+                i -= m
+            new_lst.append(i)
+        else:
+            new_lst.append(i)
+    assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
+    return new_lst
+
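For example, with a cap of 4:

```python
assert split_large_numbers([3, 9, 2], 4) == [3, 4, 4, 1, 2]  # 9 splits into 4 + 4 + 1
```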
+
+class Patcher:
+    def __init__(self, patcher_args: PatcherArgs):
+        self.patcher_args = patcher_args
+        self.patching_mode = patcher_args.patching_mode
+        self.realtime_patching = patcher_args.realtime_patching
+        if self.realtime_patching:
+            assert (
+                patcher_args.entropy_model_checkpoint_dir is not None
+            ), "Cannot require realtime patching without an entropy model checkpoint"
+            entropy_model = load_entropy_model(
+                patcher_args.entropy_model_checkpoint_dir
+            )
+            entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
+            self.entropy_model = entropy_model
+        else:
+            self.entropy_model = None
+        self.threshold = patcher_args.threshold
+        self.threshold_add = patcher_args.threshold_add
+        self.max_patch_length = patcher_args.max_patch_length
+        self.patch_size = patcher_args.patch_size
+        self.patching_batch_size = patcher_args.patching_batch_size
+        self.data_loader_patching = patcher_args.data_loader_patching
+        self.device = patcher_args.device
+        self.monotonicity = patcher_args.monotonicity
+        self.log_time = patcher_args.log_time
+        if self.log_time:
+            self.log = defaultdict(float)
+
+    def patch(
+        self,
+        tokens: torch.Tensor,
+        include_next_token: bool = False,
+        preds: torch.Tensor | None = None,
+        entropies: torch.Tensor | None = None,
+        threshold: float | None = None,
+    ) -> torch.Tensor:
+        """
+        tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
+        Returns patch lengths and optionally scores associated with the tokens (e.g. entropies, logprobs).
+        -> output tensor: [batch_size, max_num_patches]
+            each tensor is processed independently and gets right padded with zeros.
+
+        Patching with the following modes:
+        1. patching_mode = None: static patch size
+        2. patching_mode = "entropy":
+            calculate entropy of each token, allocate patches so that the total
+            number of patches is the same as static patching but choose to begin
+            patches on tokens where the model is most uncertain (highest entropy).
+
+            When threshold is provided, it uses the threshold to decide when to
+            start a new patch.
+        3. patching_mode = "space":
+            use space like tokens to define the patches.
+        4. patching_mode = "bpe":
+            use bpe delim tokens to define the patches.
+
+        To correctly patch the last token, it may be necessary to include the next token in the patch
+        lengths calculations. This is controlled by the include_next_token argument.
+        """
+        bs, seq_len = tokens.shape
+        seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
+        scores = None
+        # STATIC
+        if self.patching_mode is None:
+            patch_lengths = torch.zeros(
+                (bs, math.ceil(seq_len_next_tok / self.patch_size)),
+                dtype=tokens.dtype,
+                device=tokens.device,
+            ).fill_(self.patch_size)
+            if seq_len_next_tok % self.patch_size != 0:
+                patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
+        # ENTROPY
+        elif self.patching_mode == PatchingModeEnum.entropy:
+            if self.log_time:
+                s = time.time()
+            if entropies is not None:
+                scores = torch.tensor(entropies, dtype=torch.float32)
+            elif preds is not None:
+                scores = entropy(preds)
+            else:
+                scores = calculate_entropies(
+                    tokens,
+                    self.entropy_model,
+                    self.patching_batch_size,
+                    self.device,
+                )
+            if self.log_time:
+                self.log["calculate_entropies"] += time.time() - s
+                s = time.time()
+            patch_start_ids = find_entropy_patch_start_ids(
+                scores,
+                self.patch_size,
+                include_next_token=include_next_token,
+                threshold=threshold if threshold is not None else self.threshold,
+                threshold_add=self.threshold_add,
+                monotonicity=self.monotonicity,
+            )
+            if self.log_time:
+                self.log["find_entropy_patch_start_ids"] += time.time() - s
+                s = time.time()
+            patch_lengths = patch_lengths_from_start_ids(
+                patch_start_ids, seq_len_next_tok
+            )
+            if self.log_time:
+                self.log["patch_lengths_from_start_ids"] += time.time() - s
+                s = time.time()
+        # BPE
+        elif self.patching_mode == PatchingModeEnum.bpe:
+            patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID)
+            patch_lengths = patch_lengths_from_start_ids(
+                patch_start_ids, seq_len_next_tok
+            )
+        elif self.patching_mode == PatchingModeEnum.bpe_patcher:
+            patch_start_ids = find_bpe_patcher_patch_start_ids(
+                tokens,
+                self.entropy_model,
+                self.patching_batch_size,
+                self.device,
+                include_next_token,
+            )
+            patch_lengths = patch_lengths_from_start_ids(
+                patch_start_ids, seq_len_next_tok
+            )
+        # SPACE
+        elif self.patching_mode == PatchingModeEnum.space:
+            patch_start_ids = find_space_patch_start_ids(tokens)
+            patch_lengths = patch_lengths_from_start_ids(
+                patch_start_ids, seq_len_next_tok
+            )
+        else:
+            raise NotImplementedError(f"self.patching_mode {self.patching_mode}")
+
+        # Apply any processing to patch lengths
+        if self.max_patch_length is not None:
+            # TODO: avoid going back to a list here.
+            patch_lengths = [
+                split_large_numbers(pl, self.max_patch_length)
+                for pl in patch_lengths.tolist()
+            ]
+            max_len = max([len(pl) for pl in patch_lengths])
+            patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
+            patch_lengths = torch.tensor(
+                patch_lengths, dtype=tokens.dtype, device=tokens.device
+            )
+        assert not check_non_zero_after_zero(patch_lengths)
+        # Find the last non-zero column index using argmax on a reversed version of the tensor
+        last_non_zero_col_reversed = (
+            (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
+        )
+        # Slice the tensor up to the last non-zero column
+        patch_lengths = patch_lengths[
+            :, : patch_lengths.shape[1] - last_non_zero_col_reversed
+        ]
+        assert (
+            torch.sum(patch_lengths)
+            == tokens.numel() + include_next_token * tokens.shape[0]
+        ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}"
+        if self.log_time:
+            self.log["postprocessing_patch_lengths"] += time.time() - s
+            self.log["tokens"] += patch_lengths.sum().item()
+        return patch_lengths, scores
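
A minimal usage sketch (not part of this patch) of the space-based patcher on a toy byte sequence, using only the names defined in this module; with the default `realtime_patching=False` no entropy model is loaded:

```python
import torch

patcher = PatcherArgs(patching_mode=PatchingModeEnum.space).build()
tokens = torch.tensor([[b + OFFSET for b in b"hi there"]])
patch_lengths, _ = patcher.patch(tokens)
print(patch_lengths)  # tensor([[1, 2, 5]]) -> patches "h", "i ", "there"
```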
diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
new file mode 100644
index 0000000..fadce45
--- /dev/null
+++ b/bytelatent/distributed.py
@@ -0,0 +1,478 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import atexit
+import contextlib
+import logging
+import multiprocessing as mp
+import os
+import random
+import shutil
+import signal
+import socket
+import subprocess
+import sys
+import tempfile
+from dataclasses import asdict, dataclass
+from functools import lru_cache, partial, reduce
+from itertools import chain
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+# for no recompute ops
+import xformers.ops
+from pydantic import BaseModel, ConfigDict
+from torch import distributed as dist
+from torch.distributed import ReduceOp
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
+from torch.distributed._tensor import DTensor
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    checkpoint_wrapper,
+)
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.checkpoint import (
+    CheckpointPolicy,
+    create_selective_checkpoint_contexts,
+)
+
+from bytelatent.float8 import convert_linears_to_fp8
+
+logger = logging.getLogger()
+
+# for selective AC
+default_no_recompute_ops = {
+    torch.ops.aten.mm.default,
+    torch.ops.aten._scaled_mm.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops.c10d_functional.reduce_scatter_tensor.default,
+    torch.ops.xformers_flash.flash_fwd.default,
+    torch.ops.xformers.efficient_attention_forward_cutlass.default,
+}
+
+
+class DistributedArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    dp_shard: int = (
+        1  # In how many shards to split the model weights. Typically the number of GPUs in a node.
+    )
+    dp_replicate: int = (
+        1  # How many times to replicate the model weights. Typically the number of nodes.
+    )
+    tp_size: int = 1
+    selective_activation_checkpointing: bool = False
+    compile: bool = False
+    fsdp_type: str = "no_shard"
+    model_dtype: str = "bf16"
+    float8_recipe: str | None = None
+    float8_filter: str = r"layers\.[0-9]+\."
+
+    matmul_allow_tf32: bool = False
+    allow_bf16_reduced_precision_reduction: bool = True
+    detect_anomaly: bool = False
+
+    compile_cache_size_limit: int = 8
+
+    spawn_method: str = "forkserver"
+
+
+class EnvironmentArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    # Use GNU openMP (GOMP) instead of Intel OpenMP [Intel Math Kernel Library (MKL)]
+    MKL_SERVICE_FORCE_INTEL: str = "GNU"
+    OMP_NUM_THREADS: str = "1"
+    MKL_NUM_THREADS: str = "1"
+    # faster intra-node collectives, seems to be a cluster specific flag
+    ENABLE_INTRA_NODE_COMM: str = "1"
+    # avoids OOMs with long context
+    TORCH_NCCL_AVOID_RECORD_STREAMS: str = "1"
+    # increase the NCCL IB timeout before erroring out; a value of 22 gives roughly a 16s timeout
+    NCCL_IB_TIMEOUT: str = "22"
+    NCCL_DEBUG: str = "INFO"
+    TORCH_NCCL_ASYNC_ERROR_HANDLING: str = "1"
+
+
+def get_device_mesh(distributed_args: DistributedArgs):
+    tp_size = distributed_args.tp_size
+    dp_replicate = distributed_args.dp_replicate
+    dp_shard = distributed_args.dp_shard
+
+    assert (
+        dp_replicate * dp_shard * tp_size == get_world_size()
+    ), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"
+
+    dims = []
+    names = []
+    if dp_replicate >= 1:
+        dims.append(dp_replicate)
+        names.append("dp_replicate")
+    if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
+        dims.append(dp_shard)
+        names.append("dp_shard")
+    if tp_size > 1:
+        dims.append(tp_size)
+        names.append("tp")
+    dims = tuple(dims)
+    names = tuple(names)
+
+    return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)
+
+
+def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
+    tensor = torch.tensor(x).cuda()
+    dist.all_reduce(tensor, op=ReduceOp.MAX, group=mesh.get_group() if mesh else None)
+    return tensor
+
+
+def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
+    tensor = torch.tensor(x).cuda()
+    dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
+    return tensor
+
+
+def dist_mean_dict(x):
+    r = dict()
+    for k in x:
+        r[k] = dist_mean(x[k])
+        r[k] = r[k].item() if (r[k].dim() == 0) else r[k].tolist()
+    return r
+
+
+@lru_cache()
+def get_is_torch_run() -> bool:
+    return os.environ.get("LOCAL_RANK") is not None
+
+
+@lru_cache()
+def get_is_slurm_job() -> bool:
+    return "SLURM_JOB_ID" in os.environ and not get_is_torch_run()
+
+
+@lru_cache()
+def get_global_rank() -> int:
+    if get_is_torch_run():
+        return int(os.environ["RANK"])
+    elif get_is_slurm_job():
+        return int(os.environ["SLURM_PROCID"])
+    else:
+        return 0
+
+
+@lru_cache()
+def get_local_rank() -> int:
+    if get_is_torch_run():
+        return int(os.environ["LOCAL_RANK"])
+    elif get_is_slurm_job():
+        return int(os.environ["SLURM_LOCALID"])
+    else:
+        return 0
+
+
+@lru_cache()
+def get_world_size() -> int:
+    if get_is_torch_run():
+        return int(os.environ["WORLD_SIZE"])
+    elif get_is_slurm_job():
+        return int(os.environ["SLURM_NTASKS"])
+    else:
+        return 1
+
+
+@lru_cache()
+def get_is_master() -> bool:
+    return get_global_rank() == 0
+
+
+@lru_cache()
+def get_master_port(job_id: int) -> int:
+    if get_is_torch_run():
+        return int(os.environ["MASTER_PORT"])
+    else:
+        MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000)
+        rng = random.Random(job_id)
+        return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
+
+
+@lru_cache()
+def get_master_addr() -> str:
+    if get_is_torch_run():
+        return os.environ["MASTER_ADDR"]
+    elif get_is_slurm_job():
+        hostnames = subprocess.check_output(
+            ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]]
+        )
+        return hostnames.split()[0].decode("utf-8")
+    else:
+        return "127.0.0.1"
+
+
+def setup_env(env_args: EnvironmentArgs):
+    env_vars = env_args.model_dump()
+
+    # When using Triton, it attempts to locate prebuilt kernels in a cache
+    # located at ~/.triton/cache, but when that's backed by NFS this can fail
+    # with a "OSError: [Errno 116] Stale file handle" error. If we were to set
+    # it to a local directory it would belong to the first user who created it
+    # and it would fail for the job of any other successive user assigned to
+    # that machine. To avoid all this mess we use a temporary per-process cache.
+    triton_cache_dir = tempfile.mkdtemp()
+    atexit.register(shutil.rmtree, triton_cache_dir, ignore_errors=True)
+    env_vars["TRITON_CACHE_DIR"] = triton_cache_dir
+
+    # We change the tmp dir to /scratch in case it's a slurm job
+    # This avoids filling up the host's usually limited tmpfs
+    # A full tmpfs leads to very slow creation of processes and weird bugs
+    if get_is_slurm_job():
+        new_tmp = f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}"
+        if os.path.exists(new_tmp):
+            env_vars["TMP_DIR"] = new_tmp
+
+    for name, value in env_vars.items():
+        if os.environ.get(name) != str(value):
+            os.environ[name] = str(value)
+            logger.warning(f"WARNING: Setting {name} to {value}")
+
+
+def setup_torch_distributed(dist_args):
+    """
+    Handle single and multi-GPU / multi-node / SLURM jobs.
+    Initialize the following variables:
+        - global_rank
+        - world_size
+    """
+    mp.set_start_method(dist_args.spawn_method)
+    with mp.Manager():
+        pass
+
+    local_rank = get_local_rank()
+
+    os.environ["RANK"] = str(get_global_rank())
+    os.environ["WORLD_SIZE"] = str(get_world_size())
+    os.environ["MASTER_ADDR"] = get_master_addr()
+    os.environ["MASTER_PORT"] = str(
+        get_master_port(job_id=int(os.environ.get("SLURM_JOB_ID", -1)))
+    )
+
+    if get_is_torch_run():
+        logger.info(f"Run launched with torchrun, local rank: {local_rank}")
+    elif get_is_slurm_job():
+        logger.info(f"Run launched with slurm, local rank: {local_rank}")
+    else:
+        logger.info("Single GPU job")
+
+    logger.info(f"ENV: {os.environ}")
+
+    # set GPU device
+    assert 0 <= local_rank < 8
+    if dist_args.matmul_allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        logger.warning(
+            f"WARNING: Setting torch.backends.matmul.allow_tf32 to True. This is faster but less accurate."
+        )
+    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+        dist_args.allow_bf16_reduced_precision_reduction
+    )
+    if torch.cuda.device_count() > 1:
+        torch.cuda.set_device(local_rank)
+    torch.distributed.init_process_group(init_method="env://", backend="nccl")
+    torch.autograd.set_detect_anomaly(dist_args.detect_anomaly)
+
+
+def get_module(module, access_string):
+    names = access_string.split(sep=".")
+    return reduce(getattr, names, module)
+
+
+def set_module(module, access_string, value):
+    names = access_string.split(sep=".")
+    parent = reduce(getattr, names[:-1], module)
+    setattr(parent, names[-1], value)
+
+
+def default_fsdp_grouping_plan(n_layers: int) -> List[Tuple[str, bool]]:
+    return [(f"layers.{i}", i < n_layers - 1) for i in range(n_layers)]
+
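For instance, a 3-layer model is grouped per layer, and only the last layer skips resharding after its forward pass:

```python
assert default_fsdp_grouping_plan(3) == [
    ("layers.0", True),
    ("layers.1", True),
    ("layers.2", False),  # last layer: reshard_after_forward disabled
]
```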
+
+def get_default_policy(no_recompute_ops=None):
+    no_recompute_ops = no_recompute_ops or default_no_recompute_ops
+
+    def default_policy(ctx, func, *args, **kwargs):
+        return (
+            CheckpointPolicy.MUST_SAVE
+            if func in no_recompute_ops
+            else CheckpointPolicy.PREFER_RECOMPUTE
+        )
+
+    return default_policy
+
+
+@torch.no_grad()
+def check_model_value_range(
+    model: torch.nn.Module, range: float = 1e3, std: float = 1e3
+):
+    for name, param in chain(model.named_parameters(), model.named_buffers()):
+        if isinstance(param, DTensor):
+            param = param.to_local()
+
+        if param.numel() == 0:
+            logger.warning(
+                f"Model parameter {name} is empty, probably because of FSDP sharding"
+            )
+            continue
+
+        if torch.isnan(param).any() or torch.isinf(param).any():
+            logger.warning(f"Model parameter {name} contains NaN or Inf")
+
+        param_range = param.max() - param.min()
+        param_std = param.std()
+        if param_range > range:
+            logger.warning(
+                f"Model parameter {name} has a suspiciously large range ({param_range}): please check initialization and init_weights is defined and called"
+            )
+        if param_std > std:
+            logger.warning(
+                f"Model parameter {name} has a suspiciously large standard deviation ({param_std}): please check initialization and init_weights is defined and called"
+            )
+        if (param == 0).all():
+            logger.warning(
+                f"Model parameter {name} is all zeros: it might be because of a missing initialization"
+            )
+
+
+def init_signal_handler(callable):
+    """
+    Handle signals sent by SLURM for time limit / pre-emption.
+    """
+    signal.signal(signal.SIGUSR2, callable)
+    logger.warning("Signal handler installed.")
+
+
+def requeue_slurm_job():
+    prod_id = int(os.environ["SLURM_PROCID"])
+    logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
+    if prod_id == 0 and os.environ.get("LAUNCH_WITH", "") != "DORA":
+        logger.warning("Requeuing job " + os.environ["SLURM_JOB_ID"])
+        os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
+    else:
+        logger.warning("Not the master process, no need to requeue.")
+    sys.exit(0)
+
+
+@contextlib.contextmanager
+def clean_env():
+    distrib_names = (
+        "MASTER_ADDR",
+        "MASTER_PORT",
+        "RANK",
+        "WORLD_SIZE",
+        "LOCAL_RANK",
+        "LOCAL_WORLD_SIZE",
+        "TORCHELASTIC_RUN_ID",
+        "DORA_FORCE_DISTRIB",
+    )
+    cluster_env = {
+        x: os.environ.pop(x)
+        for x in os.environ
+        if x.startswith(
+            ("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_", "WANDB_")
+        )
+        or x in distrib_names
+    }
+    try:
+        yield
+    finally:
+        os.environ.update(cluster_env)
+
+
+def parallelize_model(
+    model,
+    device_mesh,
+    model_args,
+    distributed_args: DistributedArgs,
+    fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
+    tp_parallelize=None,
+    no_recompute_ops=None,
+):
+    if distributed_args.tp_size > 1:
+        assert (
+            distributed_args.fsdp_type == "full_shard"
+        ), "Only full shard is supported for TP parallelism"
+        assert tp_parallelize is not None, "TP plan is required for TP parallelism"
+        assert (
+            distributed_args.compile == False
+        ), "Compile is not supported for TP parallelism"
+
+        tp_parallelize(model, device_mesh["tp"], model_args, distributed_args)
+
+    if distributed_args.float8_recipe is not None:
+        if distributed_args.tp_size > 1:
+            raise RuntimeError("float8 is incompatible with tensor-parallelism for now")
+        model = convert_linears_to_fp8(
+            model, distributed_args.float8_recipe, distributed_args.float8_filter
+        )
+
+    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
+        distributed_args.model_dtype
+    ]
+    if (
+        distributed_args.fsdp_type == "full_shard"
+        or distributed_args.fsdp_type == "no_shard"
+    ):
+        if distributed_args.fsdp_type == "no_shard":
+            assert (
+                distributed_args.dp_shard == 1
+            ), "dp_shard must be 1 for no_shard fsdp_type"
+            assert (
+                device_mesh["dp_shard"].size() == 1
+            ), "dp_shard must be 1 for no_shard fsdp_type"
+
+        fsdp_config = dict(
+            mp_policy=(
+                MixedPrecisionPolicy(
+                    param_dtype=param_dtype,
+                    reduce_dtype=torch.float32,
+                )
+            ),
+            mesh=(
+                device_mesh["dp_replicate", "dp_shard"]
+                if distributed_args.dp_shard > 1
+                or distributed_args.fsdp_type == "no_shard"
+                else device_mesh["dp_replicate"]
+            ),
+        )
+
+        if fsdp_grouping_plan is None:
+            # Assume that the model has list of layers and group around it
+            fsdp_grouping_plan = default_fsdp_grouping_plan(len(model.layers))
+
+        for path, reshard_after_forward in fsdp_grouping_plan:
+            module = get_module(model, path)
+            set_module(
+                model,
+                path,
+                fully_shard(
+                    module, **fsdp_config, reshard_after_forward=reshard_after_forward
+                ),
+            )
+
+        model = fully_shard(model, **fsdp_config, reshard_after_forward=True)
+    else:
+        raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
+
+    if distributed_args.selective_activation_checkpointing:
+        model = checkpoint_wrapper(
+            model,
+            context_fn=partial(
+                create_selective_checkpoint_contexts,
+                get_default_policy(no_recompute_ops),
+            ),
+        )
+
+    if distributed_args.compile:
+        torch._dynamo.config.cache_size_limit = (
+            distributed_args.compile_cache_size_limit
+        )
+        model = torch.compile(model)
+
+    return model
diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py
new file mode 100644
index 0000000..1bd1766
--- /dev/null
+++ b/bytelatent/entropy_model.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import json
+import os
+import re
+
+import torch
+
+from bytelatent.transformer import LMTransformer, LMTransformerArgs
+
+
+def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
+    with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
+        reloaded = json.loads(fr.read())
+
+    torch.set_default_dtype(torch.bfloat16)
+    model_params = reloaded["model"]
+    entropy_model = LMTransformer(
+        LMTransformerArgs(
+            dim=model_params["dim"],
+            n_layers=model_params["n_layers"],
+            n_heads=model_params["n_heads"],
+            max_seqlen=model_params["max_length"],
+            ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
+            vocab_size=model_params["vocab_size"],
+        )
+    )
+
+    entropy_model.load_state_dict(
+        torch.load(state_dict_path, map_location=device), strict=False
+    )
+    entropy_model.to(device)
+    entropy_model = entropy_model.eval()
+    # no grads for the model:
+    for param in entropy_model.parameters():
+        param.requires_grad = False
+    return entropy_model
diff --git a/bytelatent/float8.py b/bytelatent/float8.py
new file mode 100644
index 0000000..6476862
--- /dev/null
+++ b/bytelatent/float8.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import re
+import warnings
+from typing import Callable
+
+import torch
+
+# avoid division by zero when calculating scale
+EPS = 1e-12
+
+
+def scale(t, amax_t, dtype_t):
+    min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max
+    scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v
+    t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t)
+    return t_fp8, scale_t
+
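A small sketch of what `scale` does: values are rescaled so that `amax_t` maps onto the fp8 dtype's maximum (448 for e4m3fn), and the returned scale undoes that stretch.

```python
import torch

t = torch.tensor([[1.0, -2.0]])
amax = t.abs().amax(dim=-1, keepdim=True)
t_fp8, s = scale(t, amax, torch.float8_e4m3fn)
# t_fp8 holds [224, -448] in fp8; t_fp8.float() * s recovers approximately [[1.0, -2.0]]
```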
+
+def matmul(
+    first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias
+):
+    first_fp8, scale_first = scale(first, amax_first, dtype_first)
+    second_t_fp8, scale_second_t = scale(second_t, amax_second_t, dtype_second_t)
+    output = torch._scaled_mm(
+        first_fp8,
+        second_t_fp8.t(),
+        scale_a=scale_first,
+        scale_b=scale_second_t.t(),
+        bias=bias,
+        out_dtype=torch.bfloat16,
+        use_fast_accum=True,
+    )
+    return output
+
+
+@torch._dynamo.allow_in_graph
+class Fp8LinearFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, a, b_t, bias):
+        amax_a = a.abs().amax(dim=-1, keepdim=True)
+        amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
+        out = matmul(
+            a, amax_a, torch.float8_e4m3fn, b_t, amax_b_t, torch.float8_e4m3fn, bias
+        )
+
+        ctx.a_requires_grad = a.requires_grad
+        ctx.b_requires_grad = b_t.requires_grad
+        ctx.bias_requires_grad = bias.requires_grad if bias is not None else False
+
+        ctx.save_for_backward(a, b_t, amax_b_t.max())
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        a, b_t, amax_b = ctx.saved_tensors
+
+        if ctx.a_requires_grad:
+            b = b_t.t().contiguous()
+            amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True)
+            amax_b = amax_b.repeat(b.shape[0], 1)
+            grad_a = matmul(
+                grad_out,
+                amax_grad_out,
+                torch.float8_e4m3fn,
+                b,
+                amax_b,
+                torch.float8_e4m3fn,
+                None,
+            )
+        else:
+            grad_a = None
+        if ctx.b_requires_grad:
+            grad_b = grad_out.t() @ a
+        else:
+            grad_b = None
+        if ctx.bias_requires_grad:
+            grad_bias = grad_out.sum(dim=0)
+        else:
+            grad_bias = None
+
+        return grad_a, grad_b, grad_bias
+
+
+class Fp8Linear(torch.nn.Linear):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias)
+        out = out.unflatten(0, input.shape[:-1])
+        return out
+
+
+def named_replace(
+    fn: Callable[[torch.nn.Module, str], torch.nn.Module],
+    module: torch.nn.Module,
+    name="",
+) -> torch.nn.Module:
+    for child_name, child_module in list(module.named_children()):
+        full_name = f"{name}.{child_name}" if name else child_name
+        new_child_module = named_replace(fn, child_module, full_name)
+        setattr(module, child_name, new_child_module)
+    module = fn(module, name)
+    return module
+
+
+def convert_linears_to_fp8(
+    root_module: torch.nn.Module, recipe: str, filter: str
+) -> torch.nn.Module:
+    if recipe not in ["rowwise"]:
+        raise RuntimeError(f"Unknown float8 recipe {recipe!r}")
+
+    if recipe == "rowwise" and torch.__version__ < "2.5":
+        # We need https://github.com/pytorch/pytorch/pull/134781.
+        warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0")
+
+    # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based
+    # reduction kernel and a "persistent" reduction kernel. Since fp8 has some
+    # multi-pass steps (e.g., first get amax, then scale), persistent kernels
+    # should perform better.
+    torch._inductor.config.triton.multi_kernel = 1
+
+    filter_re = re.compile(filter)
+
+    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
+        if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
+            return module
+        if type(module) == torch.nn.Linear:
+            if recipe == "rowwise":
+                new_module = Fp8Linear(
+                    in_features=module.in_features,
+                    out_features=module.out_features,
+                    bias=module.bias is not None,
+                    dtype=module.weight.dtype,
+                    device=module.weight.device,
+                )
+                new_module.weight = module.weight
+                new_module.bias = module.bias
+            else:
+                assert False, recipe
+        else:
+            assert False, str(type(module))
+        return new_module
+
+    out = named_replace(replace, root_module)
+
+    # Force re-compile everything
+    torch._dynamo.reset_code_caches()
+    from torch._inductor.cudagraph_trees import reset_cudagraph_trees
+
+    reset_cudagraph_trees()
+
+    return out
diff --git a/bytelatent/logger.py b/bytelatent/logger.py
new file mode 100644
index 0000000..6723a84
--- /dev/null
+++ b/bytelatent/logger.py
@@ -0,0 +1,129 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import logging
+import math
+import sys
+import time
+from datetime import timedelta
+
+from bytelatent.distributed import get_global_rank, get_is_slurm_job
+
+
+class LogFormatter(logging.Formatter):
+    """
+    Custom logger for distributed jobs, displaying rank
+    and preserving indent from the custom prefix format.
+    """
+
+    def __init__(self):
+        self.start_time = time.time()
+        self.rank = get_global_rank()
+        self.show_rank = not get_is_slurm_job()  # srun has --label
+
+    def formatTime(self, record):
+        subsecond, seconds = math.modf(record.created)
+        curr_date = (
+            time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
+            + f".{int(subsecond * 1_000_000):06d}"
+        )
+        delta = timedelta(seconds=round(record.created - self.start_time))
+        return f"{curr_date} - {delta}"
+
+    def formatPrefix(self, record):
+        fmt_time = self.formatTime(record)
+        if self.show_rank:
+            return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
+        else:
+            return f"{record.levelname:<7} {fmt_time} - "
+
+    def formatMessage(self, record, indent: str):
+        content = record.getMessage()
+        content = content.replace("\n", "\n" + indent)
+        # Exception handling as in the default formatter, albeit with indenting
+        # according to our custom prefix
+        if record.exc_info:
+            # Cache the traceback text to avoid converting it multiple times
+            # (it's constant anyway)
+            if not record.exc_text:
+                record.exc_text = self.formatException(record.exc_info)
+        if record.exc_text:
+            if content[-1:] != "\n":
+                content = content + "\n" + indent
+            content = content + indent.join(
+                [l + "\n" for l in record.exc_text.splitlines()]
+            )
+            if content[-1:] == "\n":
+                content = content[:-1]
+        if record.stack_info:
+            if content[-1:] != "\n":
+                content = content + "\n" + indent
+            stack_text = self.formatStack(record.stack_info)
+            content = content + indent.join([l + "\n" for l in stack_text.splitlines()])
+            if content[-1:] == "\n":
+                content = content[:-1]
+
+        return content
+
+    def format(self, record):
+        prefix = self.formatPrefix(record)
+        indent = " " * len(prefix)
+        content = self.formatMessage(record, indent)
+        return prefix + content
+
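+# Example record prefix (illustrative): "0: INFO    24-12-12 15:32:30.000123 - 0:00:05 - "
+# i.e. rank (omitted under srun --label), level, wall-clock time, then time since startup.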
+
+def set_root_log_level(log_level: str):
+    logger = logging.getLogger()
+    level: int | str = log_level.upper()
+    try:
+        level = int(log_level)
+    except ValueError:
+        pass
+    try:
+        logger.setLevel(level)  # type: ignore
+    except Exception:
+        logger.warning(
+            f"Failed to set logging level to {log_level}, using default 'NOTSET'"
+        )
+        logger.setLevel(logging.NOTSET)
+
+
+def init_logger(
+    log_file: str | None = None,
+    *,
+    name: str | None = None,
+    level: str = "NOTSET",
+):
+    """
+    Set up logging.
+
+    Args:
+        log_file: A file name to save file logs to.
+        name: The name of the logger to configure, by default the root logger.
+        level: The logging level to use.
+    """
+    set_root_log_level(level)
+    logger = logging.getLogger(name)
+
+    # stdout: everything
+    stdout_handler = logging.StreamHandler(sys.stdout)
+    stdout_handler.setLevel(logging.NOTSET)
+    stdout_handler.setFormatter(LogFormatter())
+
+    # stderr: warnings / errors and above
+    stderr_handler = logging.StreamHandler(sys.stderr)
+    stderr_handler.setLevel(logging.WARNING)
+    stderr_handler.setFormatter(LogFormatter())
+
+    # set stream handlers
+    logger.handlers.clear()
+    logger.handlers.append(stdout_handler)
+    logger.handlers.append(stderr_handler)
+
+    if log_file is not None and get_global_rank() == 0:
+        # build file handler
+        file_handler = logging.FileHandler(log_file, "a")
+        file_handler.setLevel(logging.NOTSET)
+        file_handler.setFormatter(LogFormatter())
+        # update logger
+        logger = logging.getLogger()
+        logger.addHandler(file_handler)
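+
+
+# Illustrative usage sketch (not part of the original module); assumes a script that
+# wants rank-aware console logging and, on rank 0, an extra file handler:
+#
+#     init_logger(log_file=None, level="INFO")
+#     logging.getLogger(__name__).info("logging configured")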
diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py
new file mode 100644
index 0000000..77dc4d7
--- /dev/null
+++ b/bytelatent/metrics.py
@@ -0,0 +1,232 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import json
+import logging
+from collections import namedtuple
+from dataclasses import asdict
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Union
+
+import torch
+import torch.nn as nn
+import wandb
+from pydantic import BaseModel, ConfigDict
+
+from bytelatent.distributed import get_is_master
+
+logger = logging.getLogger()
+
+
+class WandbArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    job_type: str | None = None
+    dir: str | None = None
+    project: str | None = None
+    entity: str | None = None
+    tags: list | None = None
+    group: str | None = None
+    name: str | None = None
+    notes: str | None = None
+    config_exclude_keys: list[str] | None = None
+    config_include_keys: list[str] | None = None
+    anonymous: str | None = None
+    mode: str | None = None
+    allow_val_change: bool | None = None
+    resume: Union[bool, str] | None = None
+    force: bool | None = None
+    tensorboard: bool | None = None
+    sync_tensorboard: bool | None = None
+    monitor_gym: bool | None = None
+    save_code: bool | None = None
+    id: str | None = None
+    fork_from: str | None = None
+    resume_from: str | None = None
+
+
+class LoggingArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    freq: int = 10  # Log every freq optimizer steps
+    acc_freq: int | None = None  # Log every acc_freq gradient accumulation steps
+
+    wandb: WandbArgs | None = None
+
+
+class MetricLogger:
+    def __init__(self, outdir: Path, args: Any | None = None):
+        self.outdir = outdir
+        self.jsonl_writer = None
+        self.args = args
+
+    def open(self):
+        if self.jsonl_writer is None:
+            self.jsonl_writer = open(self.outdir, "a")
+        if (
+            self.args is not None
+            and self.args.logging.wandb is not None
+            and get_is_master()
+        ):
+            run = wandb.init(
+                config=asdict(self.args),
+                **asdict(self.args.logging.wandb),
+            )
+
+    def log(self, metrics: dict[str, Any]):
+        if (
+            self.args is not None
+            and self.args.logging.wandb is not None
+            and (wandb.run is not None)
+        ):
+            wandb.log(metrics, step=metrics["global_step"])
+
+        metrics.update({"created_at": datetime.now(timezone.utc).isoformat()})
+        print(json.dumps(metrics), file=self.jsonl_writer, flush=True)
+
+    def close(self):
+        if self.jsonl_writer is not None:
+            self.jsonl_writer.close()
+            self.jsonl_writer = None
+
+    def __enter__(self):
+        self.open()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
+    def __del__(self):
+        self.close()
+
+
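+# Usage sketch (illustrative, not from the original file): with args=None the logger
+# skips wandb and only appends JSON lines to the given path.
+#
+#     with MetricLogger(Path("metrics.jsonl"), args=None) as ml:
+#         ml.log({"global_step": 0, "loss": 2.5})
+
+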
+GPUMemStats = namedtuple(
+    "GPUMemStats",
+    [
+        "max_active_gib",
+        "max_active_pct",
+        "max_reserved_gib",
+        "max_reserved_pct",
+        "num_alloc_retries",
+        "num_ooms",
+        "power_draw",
+    ],
+)
+
+
+class GPUMemoryMonitor:
+    """
+    Class to monitor GPU memory usage
+    """
+
+    def __init__(self, device: str = "cuda:0"):
+        self.device = torch.device(device)  # device object
+        self.device_name = torch.cuda.get_device_name(self.device)
+        self.device_index = torch.cuda.current_device()
+        self.device_capacity = torch.cuda.get_device_properties(
+            self.device
+        ).total_memory
+        self.device_capacity_gib = self._to_gib(self.device_capacity)
+
+        # reset stats, clear cache
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.empty_cache()
+
+    def _to_gib(self, memory_in_bytes):
+        # NOTE: GiB (gibibyte) is 1024, vs GB is 1000
+        _gib_in_bytes = 1024 * 1024 * 1024
+        memory_in_gib = memory_in_bytes / _gib_in_bytes
+        return memory_in_gib
+
+    def _to_pct(self, memory):
+        return 100 * memory / self.device_capacity
+
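+    # Worked example (illustrative): on a 16 GiB device, device_capacity is
+    # 16 * 1024**3 = 17,179,869,184 bytes, so _to_gib(8_589_934_592) == 8.0 and
+    # _to_pct(8_589_934_592) == 50.0.
+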
+    def get_peak_stats(self):
+        cuda_info = torch.cuda.memory_stats(self.device)
+
+        max_active = cuda_info["active_bytes.all.peak"]
+        max_active_gib = self._to_gib(max_active)
+        max_active_pct = self._to_pct(max_active)
+
+        max_reserved = cuda_info["reserved_bytes.all.peak"]
+        max_reserved_gib = self._to_gib(max_reserved)
+        max_reserved_pct = self._to_pct(max_reserved)
+
+        num_retries = cuda_info["num_alloc_retries"]
+        num_ooms = cuda_info["num_ooms"]
+        power_draw = torch.cuda.power_draw()
+
+        if num_retries > 0:
+            logger.warning(f"{num_retries} CUDA memory allocation retries.")
+        if num_ooms > 0:
+            logger.warning(f"{num_ooms} CUDA OOM errors thrown.")
+
+        return GPUMemStats(
+            max_active_gib,
+            max_active_pct,
+            max_reserved_gib,
+            max_reserved_pct,
+            num_retries,
+            num_ooms,
+            power_draw,
+        )
+
+    def reset_peak_stats(self):
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.reset_accumulated_memory_stats()
+
+    def __str__(self):
+        mem_stats = self.get_peak_stats()
+        display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib} GiB capacity, "
+        display_str += (
+            f"{mem_stats.max_reserved_gib} GiB peak, {mem_stats.max_reserved_pct}% peak"
+        )
+        return f"{display_str}"
+
+
+def upload_train_to_wandb(
+    ckpt_dir, project="lingua", entity="codegen-team", train=True, eval=True
+):
+    import json
+    from pathlib import Path
+
+    import wandb
+    from omegaconf import OmegaConf
+
+    cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml")
+    cfg = OmegaConf.to_container(cfg)
+
+    if train:
+        wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
+
+        with open(Path(ckpt_dir) / "metrics.jsonl") as f:
+            for l in f:
+                m = json.loads(l)
+                wandb.log(m, step=m["global_step"])
+
+        wandb.finish()
+
+    if eval:
+        wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
+
+        with open(Path(ckpt_dir) / "metrics.eval.jsonl") as f:
+            for l in f:
+                m = json.loads(l)
+                wandb.log(
+                    {
+                        f"evals/{name.replace('/','.')}": value
+                        for name, value in m.items()
+                        if "/" in name
+                    },
+                    step=m["global_step"],
+                )
+
+        wandb.finish()
+
+
+def get_num_params(model: nn.Module) -> int:
+    """
+    Get the total number of model parameters.
+    """
+    numel = {n: p.numel() for n, p in model.named_parameters()}
+    return sum(numel.values())
diff --git a/bytelatent/model/__init__.py b/bytelatent/model/__init__.py
new file mode 100644
index 0000000..71ca4b1
--- /dev/null
+++ b/bytelatent/model/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
new file mode 100644
index 0000000..9332d19
--- /dev/null
+++ b/bytelatent/model/blt.py
@@ -0,0 +1,1064 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from enum import Enum, auto
+from typing import Any, Optional
+
+import torch
+from pydantic import ConfigDict, model_validator
+from torch import nn
+from torch.nn.attention.flex_attention import create_block_mask
+from typing_extensions import Self
+
+from bytelatent.base_transformer import (
+    BaseTransformerArgs,
+    InitStdFactor,
+    TransformerBlock,
+)
+from bytelatent.data.patcher import Patcher, PatcherArgs
+from bytelatent.model.local_models import LocalDecoder, LocalEncoder
+from bytelatent.model.transformer import GlobalTransformer
+from bytelatent.model.utils import downsample
+from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
+
+
+def attention_flops_per_token(n_layers, seq_len, dim, causal):
+    # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
+    return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))
+
+
+def get_num_flop_per_token(
+    num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
+) -> float:
+    return 6 * num_non_embed_params + attention_flops_per_token(
+        n_layers, seq_len, dim, True
+    )
+
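+# Worked example (illustrative, not in the original file): with n_layers=8,
+# seq_len=1024, dim=512 and causal attention, attention_flops_per_token gives
+# 3.5 * (4 * 8 * 1024 * 512 // 2) = 29,360,128 FLOPs per token, to which
+# get_num_flop_per_token adds 6 * num_non_embed_params for the dense layers.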
+
+def causal_mask(b, h, q_idx, kv_idx):
+    return q_idx >= kv_idx
+
+
+def setattrs(_self, **kwargs):
+    for k, v in kwargs.items():
+        setattr(_self, k, v)
+
+
+def get_encoder_dim_token_emb(args):
+    if args.dim_token is not None:
+        dim_token_emb = args.dim_token
+    elif args.use_local_encoder_transformer:
+        dim_token_emb = args.dim_local_encoder
+    else:
+        dim_token_emb = args.dim_global // args.patch_size
+    return dim_token_emb
+
+
+def get_encoder_dim_patch_emb(args):
+    dim_patch_emb = None
+    if args.cross_attn_encoder:
+        if args.cross_attn_init_by_pooling:
+            dim_patch_emb = args.dim_local_encoder
+        else:
+            dim_patch_emb = args.dim_global
+    return dim_patch_emb
+
+
+def get_global_dim_patch_emb(args):
+    dim_token_emb = get_encoder_dim_token_emb(args)
+    if args.cross_attn_encoder:
+        dim_patch_emb = dim_token_emb * args.cross_attn_k
+    elif (
+        args.downsampling_by_pooling is None
+        or not args.downsampling_by_pooling
+        or len(args.downsampling_by_pooling) == 0
+    ):
+        dim_patch_emb = dim_token_emb * args.patch_size
+    else:
+        dim_patch_emb = dim_token_emb * sum(
+            [
+                pooling in args.downsampling_by_pooling
+                for pooling in ["avg", "min", "max"]
+            ]
+        )
+    return dim_patch_emb
+
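+# Example (illustrative): without cross-attention, dim_token_emb=512 and
+# downsampling_by_pooling="avg,max" concatenate two pooled views per patch,
+# so get_global_dim_patch_emb returns 512 * 2 = 1024.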
+
+def get_decoder_dim_token_emb(args):
+    if args.share_encoder_decoder_emb:
+        dim_token_emb = get_encoder_dim_token_emb(args)
+    elif args.dim_token is not None:
+        dim_token_emb = args.dim_token
+    else:
+        dim_token_emb = args.dim_local_decoder
+    return dim_token_emb
+
+
+def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None:
+    if ngram_to_size_str is None:
+        return None
+    ngram_to_size = {}
+    for entry in ngram_to_size_str.split(","):
+        ngram, size = entry.split(":")
+        ngram = int(ngram)
+        size = int(size)
+        ngram_to_size[ngram] = size
+    return ngram_to_size
+
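+# Example (illustrative): parse_ngram_to_size("2:1000,3:1500") == {2: 1000, 3: 1500}.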
+
+def fill_tokens(tokens, patch_size, fill_id):
+    batch_size, seq_len = tokens.shape
+    if seq_len % patch_size == 0:
+        return tokens
+    else:
+        remaining = patch_size - seq_len % patch_size
+        final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
+        return torch.cat((tokens, final_padding), dim=1)
+
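+# Example (illustrative): with patch_size=4 a (bs, 6) batch gains two fill_id columns on the
+# right, e.g. fill_tokens(torch.zeros(2, 6, dtype=torch.long), patch_size=4, fill_id=0).shape == (2, 8).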
+
+def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len):
+    first_patch_length = patch_lengths[0, 0]
+    assert torch.all(
+        first_patch_length == patch_lengths[:, 0]
+    ), "first patch should always be the same size (1 for dynamic, patch_size for static)."
+    assert (
+        first_patch_length - nb_boe == 1
+    ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})"
+    # Remove first patch from patch_ids for local decoder inputs and shift the last patch.
+    # decoder_patch_lengths = patch_lengths[:, 1:].clone()
+    # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1)
+    decoder_patch_lengths = patch_lengths[:, 1:]
+    assert (
+        decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]
+        == patch_lengths.sum()
+    ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}"
+    assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}"
+    decoder_patch_ids = patch_ids_from_lengths(
+        patch_lengths=decoder_patch_lengths, seq_len=seq_len
+    )
+    return decoder_patch_ids
+
+
+primes = [
+    1000000007,
+    5915587277,
+    1500450271,
+    3267000013,
+    5754853343,
+    4093082899,
+    9576890767,
+    3628273133,
+    2860486313,
+    5463458053,
+    3367900313,
+]
+
+
+def rolling_polynomial_hash(t, hash_func_nb: int = 0):
+    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
+    prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
+    return torch.sum(t * prime_powers, dim=-1)
+
+
+def get_rolling_polynomial_hash_fn(hash_func_nb: int = 0, group_size: int = 2):
+    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64)
+    prime_powers = torch.stack([prime**i for i in range(group_size)])
+
+    def rolling_polynomial_hash_fn(t):
+        return torch.sum(t * prime_powers, dim=-1)
+
+    return rolling_polynomial_hash_fn
+
+
+def byte_group_hash_function(
+    x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000
+):
+    """
+    Returns a rolling-polynomial hash of the input x, mapped to a value in the range [0, max_hash).
+
+    Expects: x of shape (batch_size, seq_len) with values that are ids in the token vocab.
+    Returns: a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash).
+
+    Note: max_hash can make a big difference in the number of collisions.
+    """
+    with torch.no_grad():
+        bs, seq_len = x.shape
+        # x_numpy = x.numpy()
+        # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False)
+        # for i in range(bs):
+        #     for j in range(seq_len):
+        #         start = max(j, j-group_size+1)
+        #         end = j+1
+        #         hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash)
+
+        prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
+        x = torch.cat([prefix, x], dim=1)
+        windows = x.unfold(1, group_size, 1)
+        # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
+        hashes = rolling_polynomial_hash(windows, hash_func_nb)
+        hash_values_range = hashes % max_hash
+    hash_values_range.requires_grad = False
+    return hash_values_range
+
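+# Shape sketch (illustrative): for x of shape (bs, seq_len), byte_group_hash_function
+# returns an integer tensor of the same shape; each position hashes the window of
+# `group_size` token ids ending at that position (left-padded with zeros), modulo max_hash.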
+
+def create_patch_mask_from_ids(
+    patch_ids, num_patches, window=None, patches_as_queries=False
+):
+    """
+    Creates a boolean mask pairing byte positions with patch indices. With window=None,
+    an entry is True iff the byte's patch id equals the patch index; with a window, a query
+    position matches key positions whose patch lies in the `window` patches ending at (and
+    including) the query's own patch.
+    Args:
+        patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
+        num_patches (int): Total number of patches.
+        window (int): If not None, only considers patches within a window of size window.
+        patches_as_queries (bool): If True, the patches are used as queries
+    Returns:
+        torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
+    """
+    bs, seq_len = patch_ids.shape
+    if not patches_as_queries:
+        q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
+        kv_ids = (
+            torch.arange(num_patches, device=patch_ids.device)
+            .unsqueeze(0)
+            .unsqueeze(0)
+            .expand(bs, seq_len, num_patches)
+        )
+    else:
+        kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
+        q_ids = (
+            torch.arange(num_patches, device=patch_ids.device)
+            .unsqueeze(0)
+            .unsqueeze(-1)
+            .expand(bs, num_patches, seq_len)
+        )
+    if window is None:
+        mask = q_ids == kv_ids
+    else:
+        mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
+    return mask
+
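+# Example (illustrative): for patch_ids [[0, 0, 1, 2]] and num_patches=3 with
+# patches_as_queries=False and window=None, the mask has shape (1, 4, 3) and row j is
+# one-hot at the patch id of token j (tokens 0-1 -> patch 0, token 2 -> patch 1, token 3 -> patch 2).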
+
+def cross_attn_mask(
+    patch_ids,
+    patch_lengths,
+    N,
+    patches_as_queries=False,
+    cross_attn_k=1,
+    window=None,
+    block_mask=True,
+):
+    bs = patch_ids.shape[0]
+    with torch.no_grad():
+        # Create the patch mask
+        cross_mask = create_patch_mask_from_ids(
+            patch_ids,
+            patch_lengths.shape[1],
+            window=window,
+            patches_as_queries=patches_as_queries,
+        ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
+        q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
+        kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
+        assert cross_mask.shape == (
+            bs,
+            q_len,
+            kv_len,
+        ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
+        if block_mask:
+
+            def patch_mask(b, h, q_idx, kv_idx):
+                return cross_mask[b, q_idx, kv_idx]
+
+            block_mask = create_block_mask(
+                patch_mask,
+                B=bs,
+                H=None,
+                Q_LEN=q_len,
+                KV_LEN=kv_len,
+                _compile=True,
+            )
+            return block_mask
+        else:
+            return torch.where(
+                cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))
+            ).unsqueeze(
+                1
+            )  # [bs, 1, q_len, kv_len]
+
+
+def get_blt_input(
+    tokens: torch.Tensor,
+    enforce_patch_size_multiple: bool,
+    nb_boe: torch.Tensor,
+    patch_size: int,
+    boe_id: int,
+):
+    """
+    This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
+    tokens respectively.
+
+    Consider the input and target sequences:
+    X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
+    Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
+    with patch_size=4
+
+    Note 1: there will be no special tokens introduced at the patch level.
+    Note 2: X_e needs to be trimmed to be passed to Global
+
+    Current without boe:
+    X_et = [[boe,boe,boe,boe] [3,4,5,6],      [7,eos,bos,8],    [9,10,eos,bos] [11,12,13, pad]]
+    X_g =  [[boe,boe,boe,boe] [3,4,5,6],      [7,eos,bos,8],    [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
+    X_dt = [[3,4,5,6]         [7,eos,bos,8],  [9,10,eos,bos],   [11,12,13]]
+    Y =    [[4,5,6,7]         [eos,bos,8,9],  [10,eos,bos,11],  [12,13,14]]
+
+    --> lag fix:
+    X_et = [[boe,boe,boe,3]   [4,5,6,7],      [eos,bos,8,9],    [10,eos,bos,11] [12,13,pad,pad]]
+    X_g =  [[boe,boe,boe,3]   [4,5,6,7],      [eos,bos,8,9],    [10,eos,bos,11]]
+    X_dt = [[3,4,5,6]         [7,eos,bos,8],  [9,10,eos,bos],   [11,12,13]]
+    Y =    [[4,5,6,7]    	  [eos,bos,8,9],  [10,eos,bos,11],  [12,13,14]]
+
+    Dynamic (current):
+    X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
+    Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
+
+    entropy patching:
+    input: 7, bos, 9, 10
+    pred (high entropy): eos, 8, 10, eos
+
+    X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
+    X_g =  [[boe],      [3,4,5,6], [7,eos],[bos,8],[9],     [10,eos]]
+    X_dt = [[3,4,5,6],  [7,eos],   [bos,8],[9],    [10,eos],[bos]]
+    Y =    [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
+
+    --> lag fix no boe (force single byte first patch):
+    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
+    X_g =  [[3],        [4,5,6,7], [eos,bos],[8,9], [10],       [eos,bos],      [11,12]] # remove last global patch
+    X_dt = [[3,4,5,6],  [7,eos],   [bos,8], [9],    [10,eos],   [bos,11,12]]
+    Y =    [4,5,6,7,    eos,bos,    8,9,    10,     eos,bos,    11,12,13]
+
+    input: 4, 7, bos, 9, 10
+    pred (high entropy): 5, eos, 8, 10, eos
+
+    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
+    X_g =  [[3],        [4]   ,   [5,6,7], [eos,bos],[8,9], [10],       [eos,bos],      [11,12]] # remove last global patch
+    X_dt = [[3]         [4,5,6],  [7,eos],   [bos,8], [9],    [10,eos],   [bos,11,12]]
+    Y =    [4,]         [5,6,7,    eos,bos,    8,9,    10,     eos,bos,    11,12,13]
+
+    Handle the last byte properly.
+    patch_lengths = [1, 1,         3,      2,         2      1           2               2         1]
+    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
+    X_g =  [[3],        [4]   ,   [5,6,7], [eos,bos],[8,9], [10],       [eos,bos],      [11,12]] # do not remove last global patch
+    X_dt = [[3]         [4,5,6],  [7,eos],   [bos,8], [9],    [10,eos],   [bos,11]       [12]]
+    Y =    [4,]         [5,6,7,    eos,bos,    8,9,    10,     eos,bos,    11,12,        13]]
+
+
+    bpe delim
+    X_et = [[3,4,5,6,7,<d>,eos,bos,<d>,8,9,<d>,10,<d>,eos,bos,11,12]
+    X_g =  [[3],          [4,5,6,7,<d>],     [eos,bos,<d>], ..
+    X_dt = [[3,4,5,6,7],  [<d>,eos,bos],     [<d>,bos,8], ..
+    Y =    [4,5,6,7,<d>,    eos,bos,<d>       8,9,<d>, ..
+
+    """
+    batch_size, seq_len = tokens.shape
+    local_encoder_tokens = tokens
+    local_decoder_tokens = tokens
+
+    if nb_boe > 0:
+        padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
+        local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
+    # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)
+
+    # create global tokens, contains boe tokens and eos
+    # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
+    # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
+    # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
+    # global_tokens += global_tokens.eq(0).int() * boe_id
+    # TODO: fix this when we want to use block causal in the global.
+
+    if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
+        local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
+
+    return local_encoder_tokens, None, local_decoder_tokens
+
+
+def patch_ids_from_lengths(patch_lengths, seq_len):
+    bs, num_patches = patch_lengths.shape
+    # Create a tensor of cumulative sums of the patch lengths
+    cum_d = torch.cat(
+        [
+            torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
+            patch_lengths.cumsum(dim=-1),
+        ],
+        dim=-1,
+    )
+    patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum(
+        dim=-2
+    ) - 1
+    assert not (
+        torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0
+    ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0"
+    return patch_ids
+
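+# Example (illustrative): patch lengths [[1, 3, 2]] cover a length-6 sequence, so
+# patch_ids_from_lengths(torch.tensor([[1, 3, 2]]), seq_len=6) returns
+# tensor([[0, 1, 1, 1, 2, 2]]): token j belongs to the patch whose cumulative span contains j.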
+
+class ByteLatentTransformerArgs(BaseTransformerArgs):
+    model_config = ConfigDict(extra="forbid")
+    # Basic model configuration
+    seed: int = 42
+    vocab_size: int = -1
+    dim: int = 512
+    n_layers: int = 8
+    n_heads: int = 8
+    # TODO: What is the purpose of this parameter?
+    weight_tying: bool = False
+    sliding_window: Optional[int] = None
+
+    # Architecture and dimensions
+    dim_token: int = 256
+    dim_global: int = 512
+    dim_local_decoder: int = 512
+    dim_local_encoder: int = 512
+    n_layers_global: int = 8
+    n_layers_local_decoder: int = 8
+    n_layers_local_encoder: int = 8
+
+    # Tokenization and patching
+    tokenization_mode: str = "bpe"
+    patch_size: float | None = None
+    patching_mode: str | None = None
+    patching_threshold: float | None = None
+    patching_threshold_add: float | None = None
+    monotonicity: bool = False
+    patching_batch_size: int = 1
+    patching_device: str = "cuda"
+    data_loader_patching: bool = False
+    max_patch_length: int | None = None
+
+    # Encoder/Decoder configuration
+    tie_local_encoder_decoder_logits: bool = False
+    use_local_encoder_transformer: bool = False
+    encoder_lm_loss: bool = False
+    max_encoder_seq_length: int | None = None
+    pad_to_max_length: bool = False
+    encoder_enable_byte_ngrams: bool = False
+    encoder_enable_byte_group_hash: bool = False
+    ngram_vocab_sizes: int | None = None
+
+    # Cross attention configurations
+    cross_attn_encoder: bool = False
+    cross_attn_decoder: bool = False
+    cross_attn_window_encoder: int | None = None
+    cross_attn_window_decoder: int | None = None
+    cross_attn_k: int | None = None
+    cross_attn_nheads: int | None = None
+    cross_attn_all_layers_decoder: bool = False
+    cross_attn_all_layers_encoder: bool = False
+    cross_attn_use_flex_attention: bool = True
+    cross_attn_init_by_pooling: bool = False
+
+    # Encoder hash configurations
+    encoder_hash_byte_group_size: Any | None = None
+    encoder_hash_byte_group_vocab: int = 30000
+    encoder_hash_byte_group_nb_functions: int = 3
+
+    # Model behavior and optimization
+    log_patch_lengths: bool = False
+    non_linearity: str = "swiglu"
+    use_rope: bool = True
+    recompute_fc1_out: bool = False
+    recompute_fc3_out: bool = False
+    recompute_attn: bool = True
+    custom_bwd: bool = False
+    layer_ckpt: str = "all"
+    efficient_attn: str | None = None
+
+    # Architecture options
+    patch_only_encoder: bool = False
+    patch_only_decoder: bool = False
+
+    # Initialization and attention
+    init_use_gaussian: bool = True
+    init_use_depth: str = "current"
+    attn_bias_type: str = "causal"
+    alpha_depth: str = "disabled"
+    max_length: int = 2048
+
+    # Norm configuration
+    norm_eps: float = 1e-5
+    norm_affine: bool = True
+    pre_norm: bool = True
+    norm_type: str = "rmsnorm"
+
+    # Additional configurations
+    multiple_of: int = 256
+    ffn_dim_multiplier: float = 1.0
+    dropout: float = 0
+    output_size: int = -1
+
+    # Additional parameters from ModelArgs
+    architecture: str = "vanilla"
+    share_encoder_decoder_emb: bool = True
+    global_local_decoder_residual_layer: str | None = None
+
+    tokenize_with_bpe_delimiter: bool = False
+    patching_thresholds_str: str | None = None
+    tie_local_encoder_decoder: bool = False
+    encoder_preds_low_entropy_toks: float | None = None
+    encoder_preds_random_toks: float | None = None
+    dim_token_emb: int | None = None
+    dim_patch_emb: int | None = None
+
+    encoder_ngram_table_dir: str | None = None
+    encoder_ngram_to_size_str: str | None = None
+
+    # Model architecture params
+    entropy_model_checkpoint_dir: str | None = None
+    entropy_model_is_ngram_model: bool = False
+    downsampling_by_pooling: str | None = None
+    n_heads_global: int = 8
+    n_heads_local_decoder: int = 8
+    n_heads_local_encoder: int = 8
+    n_kv_heads: int | None = None
+    n_kv_heads_global: int | None = None
+    conv_kernel_size: int | None = None
+    local_attention_window_len: int | None = None
+
+    # Performance optimization
+    sequence_parallel: bool = False
+    loss_parallel: bool = False
+    fuse_sequence_parallel: bool = False
+    use_fsdp: bool = True
+    attn_to_keep: str = "all"
+
+    # RoPE parameters
+    rope_theta: float = 10000.0
+    rope_use_fp32_in_outer_product: bool = False
+
+    # Parameter mixing
+    pm_size: int = 0
+
+    # Logging
+    full_logging_n_layers: int = 4
+
+    # Special token config
+    eos_id: int | None = None
+
+    @model_validator(mode="after")
+    def check_hash_byte_sizes(self) -> Self:
+        if isinstance(self.encoder_hash_byte_group_size, str):
+            self.encoder_hash_byte_group_size = [
+                int(x)
+                for x in self.encoder_hash_byte_group_size.split(",")
+                if len(x) > 0
+            ]
+        return self
+
+
+class LocalEncoderArgs(ByteLatentTransformerArgs):
+    # Local encoder specific dimensions
+    n_heads_local_encoder: int = 8
+    dim_token_emb: int | None = None
+    dim_patch_emb: int | None = None
+
+    def __post_init__(self):
+        # Override base args with local encoder specific values
+        self.dim = self.dim_local_encoder
+        self.n_layers = self.n_layers_local_encoder
+        self.n_heads = self.n_heads_local_encoder
+        self.cross_attn_decoder = False
+        self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None
+        self.attn_bias_type = "local_block_causal"
+
+
+class GlobalTransformerArgs(ByteLatentTransformerArgs):
+    # Global encoder specific dimensions
+    dim_token_emb: int | None = None
+    dim_patch_emb: int | None = None
+
+    def __post_init__(self):
+        # Override base args with global encoder specific values
+        self.dim = self.dim_global
+        self.n_layers = self.n_layers_global
+        self.n_heads = self.n_heads_global
+        self.n_kv_heads = self.n_kv_heads_global
+        self.local_attention_window_len = None
+        self.cross_attn_encoder = False
+        self.cross_attn_decoder = False
+
+
+class LocalDecoderArgs(ByteLatentTransformerArgs):
+    # Local decoder specific dimensions
+    dim_token_emb: int | None = None
+    dim_patch_emb: int | None = None
+
+    def __post_init__(self):
+        # Override base args with local decoder specific values
+        self.dim = self.dim_local_decoder
+        self.n_layers = self.n_layers_local_decoder
+        self.n_heads = self.n_heads_local_decoder
+        self.cross_attn_encoder = False
+        self.cross_attn_init_by_pooling = False
+        self.attn_bias_type = "local_block_causal"
+
+
+def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer:
+    global_args = args.model_copy(
+        deep=True,
+        update=dict(
+            dim=args.dim_global,
+            n_layers=args.n_layers_global,
+            n_heads=args.n_heads_global,
+            n_kv_heads=args.n_kv_heads_global,
+            local_attention_window_len=None,
+            dim_token_emb=get_global_dim_patch_emb(args),
+            dim_patch_emb=None,
+            cross_attn_encoder=False,
+            cross_attn_decoder=False,
+        ),
+    )
+
+    return GlobalTransformer(global_args)
+
+
+def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
+    # First deep copy the original args
+    # Replace with local encoder specific values
+    local_encoder_args = args.model_copy(
+        deep=True,
+        update=dict(
+            dim=args.dim_local_encoder,
+            n_layers=args.n_layers_local_encoder,
+            n_heads=args.n_heads_local_encoder,
+            dim_token_emb=get_encoder_dim_token_emb(args),
+            dim_patch_emb=get_encoder_dim_patch_emb(args),
+            cross_attn_decoder=False,
+            cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
+            attn_bias_type="local_block_causal",
+        ),
+    )
+
+    return LocalEncoder(local_encoder_args)
+
+
+def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
+    # First deep copy the original args
+    local_decoder_args = args.model_copy(
+        deep=True,
+        update=dict(
+            dim=args.dim_local_decoder,
+            n_layers=args.n_layers_local_decoder,
+            n_heads=args.n_heads_local_decoder,
+            cross_attn_encoder=False,
+            cross_attn_init_by_pooling=False,  # states are already defined
+            dim_token_emb=get_decoder_dim_token_emb(args),
+            dim_patch_emb=args.dim_global,
+            cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
+        ),
+    )
+
+    return LocalDecoder(local_decoder_args)
+
+
+class EmbeddingType(Enum):
+    HASH_TOK = auto()
+    NGRAM = auto()
+
+
+def init_embeddings(
+    args,
+    embedding_type: EmbeddingType,
+    local_encoder_dim: int,
+    encoder_hash_byte_group_size: list | None = None,
+):
+    if (
+        embedding_type == EmbeddingType.HASH_TOK
+        and args.encoder_hash_byte_group_size is None
+    ):
+        return None
+    if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None:
+        return None
+
+    embeddings = []
+
+    if embedding_type == EmbeddingType.HASH_TOK:
+        emb_dim = local_encoder_dim
+        encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
+        for _ in range(args.encoder_hash_byte_group_nb_functions):
+            for _ in encoder_hash_byte_group_size:
+                embeddings.append(
+                    nn.Embedding(
+                        encoder_hash_byte_group_vocab,
+                        emb_dim,
+                    )
+                )
+
+    elif embedding_type == EmbeddingType.NGRAM:
+        encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str)
+        emb_dim = local_encoder_dim
+        OFFSET = 4  # This should be passed as parameter if it's variable
+        for ngram_vocab_size in encoder_ngram_to_size.values():
+            embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim))
+
+    return nn.ModuleList(embeddings)
+
+
+def compute_hash_embeddings(
+    local_encoder_tokens: torch.Tensor,
+    local_encoder,
+    encoder_hash_tok_embedding: nn.ModuleList,
+    encoder_hash_byte_group_nb_functions: int,
+    encoder_hash_byte_group_size: list,
+    encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+    """
+    Compute embeddings using hash token embeddings.
+
+    Args:
+        local_encoder_tokens: Input tokens tensor
+        local_encoder: Encoder object with tok_embeddings method
+        encoder_hash_tok_embedding: ModuleList of hash token embeddings
+        encoder_hash_byte_group_nb_functions: Number of hash functions
+        encoder_hash_byte_group_size: List of byte group sizes
+        encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
+
+    Returns:
+        torch.Tensor: Combined embeddings
+    """
+    if encoder_hash_tok_embedding is None:
+        return None
+
+    local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)
+
+    i = 0
+    for func_nb in range(encoder_hash_byte_group_nb_functions):
+        for byte_group_size in encoder_hash_byte_group_size:
+            hash_ids = byte_group_hash_function(
+                local_encoder_tokens,
+                byte_group_size,
+                hash_func_nb=func_nb,
+                max_hash=encoder_hash_byte_group_vocab,
+            )
+            hash_tok_embedding = encoder_hash_tok_embedding[i]
+            local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
+            i += 1
+
+    assert i == len(encoder_hash_tok_embedding)
+    return local_encoder_embeds
+
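+# Note (illustrative): the embeddings list is ordered by (hash function, byte group size),
+# so len(encoder_hash_tok_embedding) must equal
+# encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size),
+# which is what the final assert checks.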
+
+class ByteLatentTransformer(nn.Module):
+    """
+    The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
+    by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
+    and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for
+    improved performance and inference efficiency.
+    """
+
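+    # Usage sketch (illustrative; hypothetical values, not from the original file):
+    #
+    #     args = ByteLatentTransformerArgs(vocab_size=260, data_loader_patching=True)
+    #     model = ByteLatentTransformer(args)
+    #     logits = model(tokens, patch_lengths=patch_lengths)  # logits over the byte vocabulary
+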
+    def __init__(self, args: ByteLatentTransformerArgs):
+        super().__init__()
+
+        # General configuration
+        self.weight_tying = args.weight_tying
+        self.sliding_window = args.sliding_window
+        self.patch_size = args.patch_size
+        self.patching_mode = args.patching_mode
+        self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
+            BOE_ID,
+            BOS_ID,
+            PAD_ID,
+            EOS_ID,
+        )
+        self.downsampling_by_pooling = args.downsampling_by_pooling
+        self.patching_threshold = args.patching_threshold
+        self.dim = args.dim
+        self.init_base_std = args.init_base_std
+        self.init_std_factor = InitStdFactor(args.init_std_factor)
+        self.max_seqlen = args.max_seqlen
+
+        # Cross attention configuration
+        self.cross_attn_encoder = args.cross_attn_encoder
+        self.cross_attn_decoder = args.cross_attn_decoder
+        self.cross_attn_k = args.cross_attn_k
+        self.cross_attn_window_encoder = args.cross_attn_window_encoder
+        self.cross_attn_window_decoder = args.cross_attn_window_decoder
+        self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention
+
+        # Encoder hash configuration
+        self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size
+        self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
+        self.encoder_hash_byte_group_nb_functions = (
+            args.encoder_hash_byte_group_nb_functions
+        )
+
+        # ByteLatent modules
+        self.local_encoder = create_local_encoder(args)
+        self.global_transformer = create_global_transformer(args)
+        self.local_decoder = create_local_decoder(args)
+        self.encoder_hash_tok_embedding = init_embeddings(
+            args,
+            EmbeddingType.HASH_TOK,
+            local_encoder_dim=self.local_encoder.dim,
+            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
+        )
+        self.encoder_ngram_embedding = init_embeddings(
+            args,
+            EmbeddingType.NGRAM,
+            local_encoder_dim=self.local_encoder.dim,
+            encoder_hash_byte_group_size=None,
+        )
+        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
+
+        # Transformer layers
+        self.layers = nn.ModuleList(
+            [TransformerBlock(args) for _ in range(args.n_layers)]
+        )
+
+        # Encoder ngram embedding tables
+        self.encoder_ngram_embedding = None
+        if args.encoder_enable_byte_ngrams:
+            self.encoder_ngram_embedding = nn.ModuleList()
+            assert args.ngram_vocab_sizes is not None
+            self.encoder_ngram_to_size = parse_ngram_to_size(
+                args.encoder_ngram_to_size_str
+            )
+            ngram_emb_dim = self.local_encoder.dim
+            for ngram_vocab_size in self.encoder_ngram_to_size.values():
+                self.encoder_ngram_embedding.append(
+                    nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
+                )
+
+        # Output layer
+        assert args.vocab_size > 0, "vocab_size must be greater than 0"
+        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+        if args.weight_tying:
+            self.output.weight = self.tok_embeddings.weight
+
+        # Patcher module
+        if not args.data_loader_patching:
+            self.patcher = Patcher(
+                PatcherArgs(
+                    patch_size=args.patch_size,
+                    patching_mode=args.patching_mode,
+                    patching_threshold=args.patching_threshold,
+                    patching_threshold_add=args.patching_threshold_add,
+                    monotonicity=args.monotonicity,
+                    max_patch_length=args.max_patch_length,
+                )
+            )
+
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        patch_lengths: Optional[torch.Tensor] = None,
+        ngram_ids: Optional[torch.Tensor] = None,
+    ):
+        # Ensure ngram_ids is either a tensor or None
+        assert (
+            isinstance(ngram_ids, torch.Tensor) or ngram_ids is None
+        ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}"
+
+        bs, N = tokens.shape  # Batch size and sequence length
+
+        # Get megabyte inputs
+        nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1)
+        local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
+            tokens=tokens,
+            enforce_patch_size_multiple=False,
+            nb_boe=nb_boe,
+            patch_size=self.patch_size,
+            boe_id=self.boe_id,
+        )
+
+        # Patching
+        if patch_lengths is None:
+            assert (
+                getattr(self, "patcher", None) is not None
+            ), "Patcher not defined and no patch_lengths passed."
+            patch_lengths, tok_scores = self.patcher.patch(
+                local_encoder_tokens,
+                include_next_token=True,
+                threshold=self.patcher.threshold,
+            )
+        else:
+            if nb_boe > 0:
+                patch_lengths[:, 0] += nb_boe
+
+        assert torch.min(patch_lengths) >= 0
+
+        # Generate patch IDs from patch_lengths
+        patch_ids = patch_ids_from_lengths(
+            patch_lengths, local_encoder_tokens.shape[-1]
+        )
+        assert torch.max(patch_ids) + 1 <= torch.max(
+            (patch_lengths != 0).sum(dim=-1)
+        ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"
+
+        cross_attn_mask_enc = None
+        # Cross-attention encoder
+        if self.cross_attn_encoder:
+            cross_attn_mask_enc = cross_attn_mask(
+                patch_ids,
+                patch_lengths,
+                N,
+                patches_as_queries=True,
+                cross_attn_k=self.cross_attn_k,
+                window=self.cross_attn_window_encoder,
+                block_mask=self.cross_attn_use_flex_attention,
+            )
+
+        # Hashing and embedding
+        local_encoder_embeds = compute_hash_embeddings(
+            local_encoder_tokens=local_encoder_tokens,
+            local_encoder=self.local_encoder,
+            encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
+            encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions,
+            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
+            encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab,
+        )
+
+        # N-gram table embeddings
+        if self.encoder_ngram_embedding is not None:
+            assert ngram_ids is not None, "ngram_ids must be provided"
+            if local_encoder_embeds is None:
+                local_encoder_embeds = self.local_encoder.tok_embeddings(
+                    local_encoder_tokens
+                )
+            assert len(ngram_ids) == len(
+                self.encoder_ngram_embedding
+            ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}"
+            for i in range(ngram_ids.shape[0]):
+                ngram_embedding = self.encoder_ngram_embedding[i]
+                ngram_embeds = ngram_embedding(ngram_ids[i])
+                assert (
+                    local_encoder_embeds.shape == ngram_embeds.shape
+                ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}"
+                local_encoder_embeds = local_encoder_embeds + ngram_embeds
+
+        # Local encoder
+        h_cross = None
+        (h_encoder, h_cross), cache_encoder = self.local_encoder(
+            tokens=local_encoder_tokens,
+            embeds=local_encoder_embeds,
+            patch_embeds=h_cross if self.cross_attn_encoder else None,
+            cross_mask=cross_attn_mask_enc,
+            num_patches=patch_lengths.shape[1],
+            patch_ids=patch_ids,
+        )
+
+        # Downsampling
+        if not self.cross_attn_encoder:
+            assert (
+                patch_ids.shape[1] == h_encoder.shape[1]
+            ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}"
+            h = downsample(
+                h_encoder,
+                patch_lengths.shape[1],
+                patch_lengths,
+                patch_ids,
+                downsampling_by_pooling=self.downsampling_by_pooling,
+                patch_size=self.patch_size,
+            )
+        else:
+            # Reshape h_cross
+            h = h_cross.view(bs, patch_lengths.shape[1], -1)
+
+        # Global transformer
+        global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id)
+        rows, cols = torch.where(local_encoder_tokens == self.eos_id)
+        eos_patch_ids = patch_ids[rows, cols]
+        global_tokens[rows, eos_patch_ids] = self.eos_id
+
+        h, _ = self.global_transformer(
+            embeds=h,
+            tokens=global_tokens,
+        )
+
+        # Unpatching
+        dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]
+
+        # Generate decoder patch IDs
+        decoder_patch_ids = decoder_patch_ids_from_lengths(
+            patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
+        )
+        assert (
+            torch.max(decoder_patch_ids) + 1 <= h.shape[1]
+        ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
+        assert (
+            decoder_patch_ids.shape[1] == dec_embeds.shape[1]
+        ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"
+
+        # Cross-attention decoder
+        if not self.cross_attn_decoder:
+            h = torch.gather(
+                h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
+            )
+            cross_attn_mask_dec = None
+            assert local_decoder_tokens.shape == h.shape[:-1]
+        else:
+            cross_attn_mask_dec = cross_attn_mask(
+                decoder_patch_ids,
+                patch_lengths,
+                N,
+                patches_as_queries=False,
+                cross_attn_k=self.cross_attn_k,
+                window=self.cross_attn_window_decoder,
+                block_mask=self.cross_attn_use_flex_attention,
+            )
+
+        # Local decoder
+        output, _ = self.local_decoder(
+            embeds=dec_embeds,
+            patch_embeds=h,
+            tokens=local_decoder_tokens,
+            cross_mask=cross_attn_mask_dec,
+        )
+        return output
+
+    def reset_parameters(self, init_std=None):
+        # Either use fixed base std or sqrt model dim
+        init_std = init_std or (self.dim ** (-0.5))
+        nn.init.trunc_normal_(
+            self.tok_embeddings.weight,
+            mean=0.0,
+            std=init_std,
+            a=-3 * init_std,
+            b=3 * init_std,
+        )
+        if not self.weight_tying:
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
+
+    def init_weights(self):
+        self.reset_parameters()
+        self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
+        for depth, layer in enumerate(self.layers):
+            factor = {
+                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
+                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
+                InitStdFactor.DIM_RATIO: self.dim / 4096,
+                InitStdFactor.DISABLED: 1.0,
+            }[self.init_std_factor]
+
+            layer.init_weights(self.init_base_std, factor)
+
+        self.local_decoder.init_weights(self.init_base_std)
+        self.global_transformer.init_weights(self.init_base_std)
+        self.local_encoder.init_weights(self.init_base_std)
+
+        # Hash embeddings are optional (init_embeddings may return None), so guard here.
+        if self.encoder_hash_tok_embedding is not None:
+            for emb in self.encoder_hash_tok_embedding:
+                nn.init.trunc_normal_(
+                    emb.weight,
+                    mean=0.0,
+                    std=self.init_base_std,
+                    a=-3 * self.init_base_std,
+                    b=3 * self.init_base_std,
+                )
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
new file mode 100644
index 0000000..8255504
--- /dev/null
+++ b/bytelatent/model/local_models.py
@@ -0,0 +1,356 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import logging
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import BlockMask
+from xformers.ops import AttentionBias
+
+from bytelatent.base_transformer import (
+    InitStdFactor,
+    RMSNorm,
+    RotaryEmbedding,
+    TransformerBlock,
+)
+from bytelatent.model.transformer import CrossAttention
+from bytelatent.model.utils import create_causal_mask, downsample
+from bytelatent.tokenizers.blt_tokenizer import BOE_ID
+
+logger = logging.getLogger()
+
+
+class LocalModelBase(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+
+        self.dim = args.dim
+        self.dropout = args.dropout
+        self.vocab_size = args.vocab_size + args.pm_size
+        self.patch_size = args.patch_size
+
+        self.efficient_attn = args.efficient_attn
+        self.sliding_window = args.sliding_window
+        self.use_rope = args.use_rope
+        self.init_std_factor = args.init_std_factor
+        self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
+        self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
+        self.cross_attn_k = getattr(args, "cross_attn_k", None)
+
+        self.boe_id = BOE_ID
+
+        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.layers = nn.ModuleList(
+            [TransformerBlock(args) for _ in range(args.n_layers)]
+        )
+
+        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
+        if not self.use_rope:
+            self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
+        else:
+            self.rope = RotaryEmbedding(
+                theta=args.rope_theta,
+                head_dim=args.head_dim or args.dim // args.n_heads,
+                max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
+            )
+            self.pos_embeddings = None
+
+        self.token_embedding_projection = (
+            nn.Linear(args.dim_token_emb, args.dim, bias=False)
+            if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
+            else None
+        )
+
+        self.patch_embedding_projection = self._create_patch_projection(args)
+
+    def _should_create_patch_projection(self, args):
+        dimension_mismatch = (
+            getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
+        )
+
+        # Check cross attention conditions
+        cross_attn_conditions = (
+            hasattr(args, "cross_attn_encoder")
+            and args.cross_attn_encoder
+            and getattr(args, "cross_attn_init_by_pooling")
+        ) or (
+            hasattr(args, "cross_attn_decoder")
+            and args.cross_attn_decoder
+            and getattr(args, "cross_attn_init_by_pooling")
+        )
+
+        return dimension_mismatch or cross_attn_conditions
+
+    def _create_patch_projection(self, args):
+        if not self._should_create_patch_projection(args):
+            return None
+
+        output_dim = args.dim_token_emb * (self.cross_attn_k or 1)
+
+        return nn.Linear(
+            in_features=args.dim_patch_emb,
+            out_features=output_dim,
+            bias=False,
+        )
+
+    def apply_embedding(self, tokens, embeds):
+        if embeds is not None:
+            return embeds
+        else:
+            return self.tok_embeddings(tokens)
+
+    def init_weights(self, init_std=None):
+        self.rope.reset_parameters()
+
+        init_std = init_std or (self.dim ** (-0.5))
+        nn.init.trunc_normal_(
+            self.tok_embeddings.weight,
+            mean=0.0,
+            std=init_std,
+            a=-3 * init_std,
+            b=3 * init_std,
+        )
+        if self.pos_embeddings is not None:
+            nn.init.trunc_normal_(
+                self.pos_embeddings.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
+
+        for depth, layer in enumerate(self.layers):
+            factor = {
+                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
+                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
+                InitStdFactor.DIM_RATIO: self.dim / 4096,
+                InitStdFactor.DISABLED: 1.0,
+            }[self.init_std_factor]
+
+            layer.init_weights(init_std, factor)
+
+        if self.token_embedding_projection is not None:
+            nn.init.trunc_normal_(
+                self.token_embedding_projection.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
+
+        if self.patch_embedding_projection is not None:
+            nn.init.trunc_normal_(
+                self.patch_embedding_projection.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
+
+        if hasattr(self, "output"):
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
+
+        # cross_attn_layers only exists on models that enable cross attention.
+        if getattr(self, "cross_attn_layers", None) is not None:
+            for depth, layer in enumerate(self.cross_attn_layers):
+                factor = {
+                    InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
+                    InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
+                    InitStdFactor.DIM_RATIO: self.dim / 4096,
+                    InitStdFactor.DISABLED: 1.0,
+                }[self.init_std_factor]
+
+                layer.init_weights(init_std, factor)
+
+
+class LocalEncoder(LocalModelBase):
+    def __init__(self, args):
+        super().__init__(args)
+        self.output_proj = (
+            args.patching_mode in ["entropy", "probmax"]
+        ) and args.entropy_model_checkpoint_dir is None
+
+        self.apply_transformer = args.use_local_encoder_transformer
+        self.downsampling_by_pooling = args.downsampling_by_pooling
+        self.patch_only = args.patch_only_encoder
+        self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
+        self.cross_attn_encoder = args.cross_attn_encoder
+        self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
+        self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
+        self.cross_attn_nheads = args.cross_attn_nheads
+
+        if self.cross_attn_encoder:
+            self.cross_attn_layers = torch.nn.ModuleList()
+            layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
+            for _ in range(layers_to_add):
+                self.cross_attn_layers.append(
+                    CrossAttention(
+                        dim=self.dim,
+                        head_dim=self.dim // self.cross_attn_nheads,
+                        n_heads=self.cross_attn_nheads,
+                        n_kv_heads=self.cross_attn_nheads,
+                        norm_eps=args.norm_eps,
+                    )
+                )
+
+    def apply_embedding(self, tokens, embeds):
+        if embeds is not None:
+            assert (
+                self.expects_hash_embeddings
+            ), "Not expecting embeddings to be passed."
+            return embeds
+        else:
+            return self.tok_embeddings(tokens)
+
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        embeds: Optional[torch.Tensor] = None,
+        patch_embeds: Optional[torch.Tensor] = None,
+        mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
+        cross_mask: Optional[torch.Tensor] = None,
+        num_patches: Optional[int] = None,
+        patch_ids: Optional[torch.Tensor] = None,
+        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+    ):
+        """ """
+        bs, seqlen = tokens.shape
+        if mask is None:
+            mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
+
+        h = self.apply_embedding(tokens, embeds)
+        freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
+
+        h = F.dropout(h, p=self.dropout, training=self.training)
+
+        for i, layer in enumerate(self.layers):
+            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
+            # check if cross attention should be applied to all layers or only to the last layer
+            if self.cross_attn_encoder and (
+                i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
+            ):
+                patch_embeds = self.apply_cross_attention(
+                    h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask
+                )
+
+        h_residual = patch_embeds if self.cross_attn_encoder else None
+        return (h, h_residual), cache
+
+    def apply_cross_attention(
+        self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask
+    ):
+        # apply pooling and project
+        if self.cross_attn_init_by_pooling and patch_embeds is None:
+            patch_embeds = downsample(
+                h,
+                num_patches,
+                patch_ids=patch_ids,
+                downsampling_by_pooling=self.downsampling_by_pooling,
+                patch_size=self.patch_size,
+            )
+            if self.patch_embedding_projection is not None:
+                patch_embeds = self.patch_embedding_projection(patch_embeds)
+                patch_embeds = patch_embeds.reshape(
+                    bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
+                )
+
+        layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0
+        patch_embeds_cross = self.cross_attn_layers[layer_idx](
+            x=patch_embeds,
+            kv=h,
+            mask=cross_mask,
+        )
+        patch_embeds += patch_embeds_cross
+        return patch_embeds
+
+
+class LocalDecoder(LocalModelBase):
+    def __init__(self, args):
+        super().__init__(args)
+
+        # Model configuration flags
+        self.patch_only = args.patch_only_decoder
+        self.expects_embeddings = args.share_encoder_decoder_emb
+        self.cross_attn_decoder = args.cross_attn_decoder
+        self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
+        self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
+        self.cross_attn_nheads = args.cross_attn_nheads
+
+        if self.cross_attn_decoder:
+            self.cross_attn_layers = torch.nn.ModuleList()
+            layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1
+            for _ in range(layers_to_add):
+                self.cross_attn_layers.append(
+                    CrossAttention(
+                        dim=self.dim,
+                        head_dim=self.dim // self.cross_attn_nheads,
+                        n_heads=self.cross_attn_nheads,
+                        n_kv_heads=self.cross_attn_nheads,
+                        norm_eps=args.norm_eps,
+                    )
+                )
+
+        self.output = nn.Linear(
+            self.dim,
+            args.vocab_size,
+            bias=False,
+        )
+
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        embeds: Optional[torch.Tensor],
+        patch_embeds: Optional[torch.Tensor] = None,
+        mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
+        cross_mask: Optional[torch.Tensor] = None,
+        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+    ):
+        bs, seqlen = tokens.shape
+        assert embeds is not None, "Embeddings must be provided"
+
+        if mask is None:
+            mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
+
+        h = embeds
+
+        if self.patch_embedding_projection is not None:
+            assert patch_embeds is not None, "Patch embeddings must be passed."
+            patch_embeds = self.patch_embedding_projection(patch_embeds)
+            if self.cross_attn_k is not None:
+                patch_embeds = patch_embeds.reshape(
+                    bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
+                )
+
+        if patch_embeds is not None and not self.cross_attn_decoder:
+            h = h + patch_embeds
+
+        freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
+
+        h = F.dropout(h, p=self.dropout, training=self.training)
+        for i, layer in enumerate(self.layers):
+            if self.cross_attn_decoder and (
+                i == 0 or self.cross_attn_all_layers_decoder
+            ):
+                # Use cross attention to extract info from patch_embeds into h
+                h_cross = self.cross_attn_layers[i](
+                    x=h,
+                    kv=patch_embeds,
+                    mask=cross_mask,
+                )
+                h = h + h_cross
+
+            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
+
+        h_preds = self.norm(h)
+        h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
+        h_preds = self.output(h_preds)
+        h_preds = h_preds.float()
+        return h_preds, cache
diff --git a/bytelatent/model/transformer.py b/bytelatent/model/transformer.py
new file mode 100644
index 0000000..24dc057
--- /dev/null
+++ b/bytelatent/model/transformer.py
@@ -0,0 +1,199 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import BlockMask
+from xformers.ops import AttentionBias
+
+from bytelatent.base_transformer import (
+    BaseTransformer,
+    RMSNorm,
+    flex_attention_comp,
+    repeat_kv,
+)
+from bytelatent.model.utils import create_causal_mask
+
+logger = logging.getLogger()
+
+
+class CrossAttention(nn.Module):
+    """
+    CrossAttention block to attend to the encoder states from the decoder.
+    Rope is not supported.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        head_dim: int,
+        n_heads: int,
+        n_kv_heads: int,
+        norm_eps: float,
+    ):
+        super().__init__()
+
+        self.dim = dim
+        self.head_dim = head_dim
+
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.heads_per_group = self.n_heads // self.n_kv_heads
+
+        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
+        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
+
+        self.wq = nn.Linear(
+            dim,
+            n_heads * head_dim,
+            bias=False,
+        )
+        self.wk = nn.Linear(
+            dim,
+            n_kv_heads * head_dim,
+            bias=False,
+        )
+        self.wv = nn.Linear(
+            dim,
+            n_kv_heads * head_dim,
+            bias=False,
+        )
+
+        self.wo = nn.Linear(
+            n_heads * head_dim,
+            dim,
+            bias=False,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        kv: torch.Tensor,
+        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
+    ) -> torch.Tensor:
+        # B S D
+        bsz, seq_len, _ = x.shape
+        _, slen_kv, _ = kv.shape
+        x = self.cross_attn_norm_q(x)
+        kv = self.cross_attn_norm_kv(kv)
+
+        xq = self.wq(x)
+        xk = self.wk(kv)
+        xv = self.wv(kv)
+
+        output_shape = xq.shape
+        # B S D -> B S H D
+        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
+        xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
+        xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
+
+        xk = repeat_kv(xk, self.heads_per_group, dim=2)
+        xv = repeat_kv(xv, self.heads_per_group, dim=2)
+
+        assert mask is None or isinstance(mask, BlockMask)
+        xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
+        output = flex_attention_comp(xq, xk, xv, block_mask=mask)
+        output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D
+
+        output = self.wo(output.reshape(output_shape))
+
+        return x + output
+
+    def init_weights(self, base_std: float, factor: float = 1.0):
+        std = base_std * factor
+
+        nn.init.trunc_normal_(
+            self.wq.weight,
+            mean=0.0,
+            std=std,
+            a=-3 * std,
+            b=3 * std,
+        )
+
+        nn.init.trunc_normal_(
+            self.wk.weight,
+            mean=0.0,
+            std=std,
+            a=-3 * std,
+            b=3 * std,
+        )
+
+        nn.init.trunc_normal_(
+            self.wv.weight,
+            mean=0.0,
+            std=std,
+            a=-3 * std,
+            b=3 * std,
+        )
+
+        output_std = std / (2**0.5)
+        nn.init.trunc_normal_(
+            self.wo.weight,
+            mean=0.0,
+            std=output_std,
+            a=-3 * output_std,
+            b=3 * output_std,
+        )
+        self.cross_attn_norm_q.reset_parameters()
+        self.cross_attn_norm_kv.reset_parameters()
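+
+    # Editor's sketch (illustrative, not part of the original commit): expected shapes.
+    # In the encoder this block is called with patch embeddings as queries and byte-level
+    # states as keys/values; the decoder reverses the roles. With mask=None it performs
+    # full (non-causal) cross attention via flex_attention. The dims below are arbitrary.
+    #
+    #   attn = CrossAttention(dim=512, head_dim=64, n_heads=8, n_kv_heads=8, norm_eps=1e-5)
+    #   x = torch.randn(2, 16, 512)    # queries, e.g. patch embeddings
+    #   kv = torch.randn(2, 128, 512)  # keys/values, e.g. byte-level hidden states
+    #   out = attn(x, kv, mask=None)   # a residual connection is applied inside forward
+    #   assert out.shape == (2, 16, 512)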
+
+
+class GlobalTransformer(BaseTransformer):
+    def __init__(self, args):
+        super().__init__(args)
+        self.dropout = args.dropout
+        self.sliding_window = args.sliding_window
+        self.efficient_attn = args.efficient_attn
+
+        self.token_embedding_projection = None
+        if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
+            self.token_embedding_projection = nn.Linear(
+                args.dim_token_emb,
+                args.dim,
+                bias=False,
+            )
+
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        tok_idx: Optional[torch.Tensor] = None,
+        embeds: Optional[torch.Tensor] = None,
+        mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
+        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+    ):
+        """
+        Similar to BaseTransformer.forward, but takes an additional embeds argument
+        and projects it from the token-embedding dimension to the model dimension when needed.
+        """
+        bs, seqlen = tokens.shape
+        attn_impl = self.efficient_attn
+
+        h = embeds
+
+        mask = (
+            mask
+            if mask is not None
+            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
+        )
+
+        if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
+            h = self.token_embedding_projection(h)
+
+        h = F.dropout(h, p=self.dropout, training=self.training)
+
+        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
+        return h, cache
+
+    def init_weights(self, init_base_std: float):
+        super().init_weights()
+        if self.token_embedding_projection is not None:
+            nn.init.trunc_normal_(
+                self.token_embedding_projection.weight,
+                mean=0.0,
+                std=init_base_std,
+                a=-3 * init_base_std,
+                b=3 * init_base_std,
+            )
diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py
new file mode 100644
index 0000000..ce52a30
--- /dev/null
+++ b/bytelatent/model/utils.py
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import torch
+from torch.nn.attention.flex_attention import create_block_mask
+from xformers.ops import fmha
+
+
+def patch_reduce(h, max_num_patches, reduction, patch_ids):
+    """
+    Reduce variable-length patches to a single embedding per patch.
+    Note: this works with a variable number of patches across the sequences in a batch.
+    Variable-length patches are handled by assuming that patch_lengths is 0 for any
+    extra patches on the *right*. Any embeddings on the right that are not allocated
+    to a patch (i.e. if sum(patch_lengths[i]) < seq_len for some i) are scattered into
+    a dummy patch, which is trimmed before returning.
+    """
+    bs, seq_len, emb_dim = h.shape
+
+    patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
+
+    reduced_embs = torch.zeros(
+        (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device
+    )
+    reduced_embs = reduced_embs.scatter_reduce(
+        src=h,
+        dim=1,
+        index=patch_ids,
+        reduce=reduction,
+        include_self=False,
+    )
+    reduced_embs = reduced_embs[:, :max_num_patches, :]
+
+    return reduced_embs
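+
+# Editor's note: a small worked example of the scatter-reduce above (illustrative, not
+# part of the original commit). With bs=1, seq_len=4, emb_dim=1, patch_ids=[[0, 0, 1, 1]]
+# and reduction="mean", each of the two patches averages its two byte embeddings:
+#
+#   h = torch.tensor([[[1.0], [3.0], [5.0], [7.0]]])   # [1, 4, 1]
+#   ids = torch.tensor([[0, 0, 1, 1]])
+#   patch_reduce(h, max_num_patches=2, reduction="mean", patch_ids=ids)
+#   # -> tensor([[[2.0], [6.0]]])                       # [1, 2, 1]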
+
+
+def concat_downsample(h, patch_lengths, patch_size):
+    # The assumption in this function is that seq_len = patch_size * num_patches.
+    bs, seq_len, emb_dim = h.shape
+    patch_end_ids = torch.cumsum(patch_lengths, dim=1)
+    patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to(
+        patch_end_ids.device
+    )
+    # Is clamp ok here?
+    patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1])
+    patch_ids = patch_ids.view(bs, -1, emb_dim)
+    # after gather h.shape = [batch_size, seq_len, dim]
+    h = torch.gather(h, 1, patch_ids)
+    h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1))
+    return h
+
+
+def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids):
+    cat = []
+    if "avg" in pooling_mode or "mean" in pooling_mode:
+        cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids))
+    if "min" in pooling_mode:
+        cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids))
+    if "max" in pooling_mode:
+        cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids))
+    assert len(cat) > 0
+    h = torch.cat(cat, dim=-1)
+    return h
+
+
+def downsample(
+    h,
+    num_patches,
+    patch_lengths=None,
+    patch_ids=None,
+    downsampling_by_pooling=None,
+    patch_size=4,
+):
+    """
+    Downsampling:
+        a. concatenating embeddings in the patch
+            Note: with dynamic patching, only the last patch_size tokens of each patch are concatenated.
+        b. pooling embeddings in the patch
+    """
+    # input: h.shape = [batch_size, seq_len, dim]
+    # pooled output: h.shape = [batch_size, seq_len / patch_size, dim]
+    # if we don't use cross_attn, we pool so that the byte representation is converted to a patch representation
+    if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0:
+        # By pooling
+        max_num_patches = num_patches
+        assert patch_ids is not None
+        h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids)
+    else:
+        # TODO: remove this condition
+        # By concatenating (fixed lengths patching)
+        assert patch_lengths is not None
+        h = concat_downsample(h, patch_lengths, patch_size)
+    return h
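+
+# Editor's sketch of the two paths above (illustrative, not part of the original commit).
+# Pooling keeps the model dim per patch (concatenating along the last dim if several
+# pooling modes are requested), while concatenation stacks the last patch_size byte
+# embeddings of each patch, so its output dim is patch_size * dim:
+#
+#   h = torch.randn(2, 8, 16)                             # [bs, seq_len, dim]
+#   patch_ids = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1]] * 2)
+#   pooled = downsample(h, 2, patch_ids=patch_ids,
+#                       downsampling_by_pooling="max")    # -> [2, 2, 16]
+#   patch_lengths = torch.tensor([[4, 4]] * 2)
+#   concat = downsample(h, 2, patch_lengths=patch_lengths,
+#                       downsampling_by_pooling=None, patch_size=4)  # -> [2, 2, 64]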
+
+
+def causal_mask(b, h, q_idx, kv_idx):
+    return q_idx >= kv_idx
+
+
+def create_causal_mask(seqlen, attn_impl, sliding_window):
+    if sliding_window is not None and attn_impl == "xformers":
+        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
+            window_left=sliding_window - 1, window_right=0
+        )
+    elif attn_impl == "xformers":
+        return fmha.attn_bias.LowerTriangularMask()
+    elif attn_impl == "sdpa":
+        return "causal"
+    elif attn_impl == "flex_attention":
+        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
+    elif attn_impl == "fmha":
+        return None
+    else:
+        raise NotImplementedError(
+            f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
+        )
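+
+# Editor's note (illustrative, not part of the original commit): the returned mask type
+# depends entirely on the attention backend, e.g.
+#
+#   create_causal_mask(8, "sdpa", None)            # -> the string "causal"
+#   create_causal_mask(8, "xformers", None)        # -> fmha.attn_bias.LowerTriangularMask()
+#   create_causal_mask(8, "xformers", 4)           # -> local-attention bias with window_left=3
+#   create_causal_mask(8, "flex_attention", None)  # -> a BlockMask from create_block_mask
+#   create_causal_mask(8, "fmha", None)            # -> None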
diff --git a/bytelatent/optim.py b/bytelatent/optim.py
new file mode 100644
index 0000000..c6e1829
--- /dev/null
+++ b/bytelatent/optim.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import logging
+import math
+from functools import partial
+
+from pydantic import BaseModel, ConfigDict
+from torch import nn
+from torch.optim import AdamW, lr_scheduler
+
+logger = logging.getLogger()
+
+
+class OptimArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    lr: float = 3e-4
+    weight_decay: float = 0.1
+    epsilon: float = 1e-8
+    beta1: float = 0.9
+    beta2: float = 0.95
+    clip: float = 1.0
+
+    scheduler: str = "cosine"
+    warmup: int = 2000
+    lr_min_ratio: float = 0.1
+    cycle_length: float = 1.0
+    cosine_theta: float = 1.0
+    annealing_step: int = 1000
+    decay_fraction: float = 0.1
+
+    exp_factor: float = 0.5
+
+
+def lr_linear(step: int, warmup: int, n_steps: int, min_ratio: float) -> float:
+    if step < warmup:
+        lr = float(step) / warmup
+    elif step <= n_steps:
+        s = float(step - warmup) / (n_steps - warmup)
+        lr = s * min_ratio + (1 - s)
+    else:
+        lr = min_ratio
+    return lr
+
+
+def lr_inv_sqrt(step: int, warmup: int, exp_factor: float, min_ratio: float) -> float:
+    if step < warmup:
+        lr = float(step) / warmup
+    else:
+        lr = max((warmup**exp_factor) / (step**exp_factor), min_ratio)
+    return lr
+
+
+def lr_cosine(
+    step: int,
+    warmup: int,
+    n_steps: int,
+    cycle_length: float,
+    theta: float,
+    min_ratio: float,
+) -> float:
+    sign = ((step // (n_steps * cycle_length)) % 2) * -2 + 1
+    if step < warmup:
+        lr = float(step) / warmup
+    elif step <= n_steps:
+        s = float(step - warmup) / (n_steps - warmup)
+        lr = min_ratio + 0.5 * (1 - min_ratio) * (
+            sign * math.cos(math.pi * s**theta / cycle_length) + 1
+        )
+    else:
+        lr = min_ratio
+    return lr
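+
+# Editor's note: a quick numeric sketch of lr_cosine (illustrative, not part of the
+# original commit). With warmup=10, n_steps=100, cycle_length=1.0, theta=1.0,
+# min_ratio=0.1, the multiplier ramps up linearly, then follows a cosine down toward
+# min_ratio:
+#
+#   lr_cosine(5,  10, 100, 1.0, 1.0, 0.1)   # warmup:    0.5
+#   lr_cosine(55, 10, 100, 1.0, 1.0, 0.1)   # mid-decay: 0.55
+#   lr_cosine(99, 10, 100, 1.0, 1.0, 0.1)   # near end:  ~0.10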
+
+
+def lr_wsd(
+    step: int,
+    warmup: int,
+    n_steps: int,
+    decay_fraction: float,
+    cycle_length: float,
+    min_ratio: float,
+) -> float:
+    """
+    UNDERSTANDING WARMUP-STABLE-DECAY LEARNING RATES: A RIVER VALLEY LOSS LANDSCAPE PERSPECTIVE
+    https://arxiv.org/pdf/2410.05192
+    """
+    cycle_num = step // int(n_steps * cycle_length) + 1
+    curr_n_steps = int(n_steps * cycle_length) * cycle_num
+    decay_length = int(curr_n_steps * decay_fraction)
+
+    if step < warmup:
+        lr = float(step) / warmup
+    elif step <= curr_n_steps - decay_length:
+        lr = 1.0
+    elif step > curr_n_steps - decay_length and step <= curr_n_steps:
+        # Linear interpolation gives similar results
+        # slope = -(1.0 - min_ratio) / decay_length
+        # intercept = min_ratio + ((1.0 - min_ratio) * curr_n_steps) / decay_length
+        # lr = slope * step + intercept
+
+        step = step - (curr_n_steps - decay_length)
+        lr = 1 / ((step / curr_n_steps) * (1 / min_ratio) + (1 - step / curr_n_steps))
+    else:
+        lr = min_ratio
+
+    return lr
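+
+# Editor's note: a quick numeric sketch of lr_wsd (illustrative, not part of the
+# original commit). With warmup=10, n_steps=100, decay_fraction=0.1, cycle_length=1.0,
+# min_ratio=0.1, the multiplier warms up, stays at 1.0, then follows a 1/x decay over
+# the last 10% of the cycle:
+#
+#   lr_wsd(5,  10, 100, 0.1, 1.0, 0.1)   # warmup: 0.5
+#   lr_wsd(50, 10, 100, 0.1, 1.0, 0.1)   # stable: 1.0
+#   lr_wsd(95, 10, 100, 0.1, 1.0, 0.1)   # decay:  ~0.69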
+
+
+def build_lr_fn(args: OptimArgs, n_steps: int):
+    if args.scheduler == "constant":
+        lr_fn = lambda x: 1.0
+    elif args.scheduler == "linear":
+        lr_fn = partial(
+            lr_linear, warmup=args.warmup, n_steps=n_steps, min_ratio=args.lr_min_ratio
+        )
+    elif args.scheduler == "inv_sqrt":
+        lr_fn = partial(
+            lr_inv_sqrt,
+            warmup=args.warmup,
+            exp_factor=args.exp_factor,
+            min_ratio=args.lr_min_ratio,
+        )
+    elif args.scheduler == "cosine":
+        lr_fn = partial(
+            lr_cosine,
+            warmup=args.warmup,
+            n_steps=n_steps,
+            cycle_length=args.cycle_length,
+            theta=args.cosine_theta,
+            min_ratio=args.lr_min_ratio,
+        )
+    elif args.scheduler == "wsd":
+        assert args.decay_fraction < args.cycle_length
+        lr_fn = partial(
+            lr_wsd,
+            warmup=args.warmup,
+            n_steps=n_steps,
+            decay_fraction=args.decay_fraction,
+            cycle_length=args.cycle_length,
+            min_ratio=args.lr_min_ratio,
+        )
+    else:
+        raise NotImplementedError(f"Unknown scheduler: {args.scheduler}")
+    return lr_fn
+
+
+def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int):
+    logger.info("Starting build of optimizer...")
+    optimizer = AdamW(
+        model.parameters(),
+        lr=args.lr,
+        betas=(args.beta1, args.beta2),
+        weight_decay=args.weight_decay,
+        eps=args.epsilon,
+        fused=True,  # Faster optim.step but can throw errors
+    )
+
+    # scheduler
+    lr_fn = build_lr_fn(args, n_steps)
+    scheduler = lr_scheduler.LambdaLR(optimizer, lr_fn)
+
+    logger.info("Done with build of optimizer.")
+    return optimizer, scheduler
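+
+# Editor's sketch of how the pieces above fit together (illustrative, not part of the
+# original commit): LambdaLR multiplies args.lr by the schedule value at each step.
+# Note that fused AdamW generally expects parameters on GPU.
+#
+#   args = OptimArgs(lr=3e-4, scheduler="cosine", warmup=100)
+#   model = nn.Linear(16, 16)
+#   optimizer, scheduler = build_optimizer(model, args, n_steps=1000)
+#   for _ in range(10):
+#       optimizer.step()    # after loss.backward() in a real training loop
+#       scheduler.step()
+#   print(scheduler.get_last_lr())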
diff --git a/bytelatent/plotting/__init__.py b/bytelatent/plotting/__init__.py
new file mode 100644
index 0000000..71ca4b1
--- /dev/null
+++ b/bytelatent/plotting/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
diff --git a/bytelatent/plotting/config_entropy_figure.yaml b/bytelatent/plotting/config_entropy_figure.yaml
new file mode 100644
index 0000000..4d7bfd7
--- /dev/null
+++ b/bytelatent/plotting/config_entropy_figure.yaml
@@ -0,0 +1,3 @@
+data_path: plot_data/entropy_figure.json
+chart_path: figures/entropy_figure.pdf
+# chart_path: figures/entropy_figure.pdf
diff --git a/bytelatent/plotting/config_scaling_figures.yaml b/bytelatent/plotting/config_scaling_figures.yaml
new file mode 100644
index 0000000..cda85c2
--- /dev/null
+++ b/bytelatent/plotting/config_scaling_figures.yaml
@@ -0,0 +1,4 @@
+df_dir: /home/par/blt_df
+output_chart_dir: figures/
+frame_files:
+  ["4b_df.json", "500m_df.json", "scaling_arch_df.json", "scaling_df.json"]
diff --git a/bytelatent/plotting/entropy_figure.py b/bytelatent/plotting/entropy_figure.py
new file mode 100644
index 0000000..c1966a1
--- /dev/null
+++ b/bytelatent/plotting/entropy_figure.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import os
+import sys
+from pathlib import Path
+
+import altair as alt
+import pandas as pd
+from omegaconf import OmegaConf
+from pydantic import BaseModel
+
+
+class PlotEntropiesConfig(BaseModel):
+    data_path: str | None
+    chart_path: str
+
+    class Config:
+        extra = "forbid"
+
+
+class PlotEntropiesData(BaseModel):
+    text: str
+    threshold: float = 1.335442066192627
+    dataframe_json: str | None
+
+    class Config:
+        extra = "forbid"
+
+
+def main():
+    config_path = sys.argv[1]
+    file_config = OmegaConf.load(config_path)
+    # Omit program name and config file name
+    cli_conf = OmegaConf.from_cli(sys.argv[2:])
+    conf_dict = OmegaConf.to_container(
+        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
+    )
+    plot_config = PlotEntropiesConfig(**conf_dict)
+    with open(plot_config.data_path) as f:
+        json_data = f.read()
+    plot_data = PlotEntropiesData.model_validate_json(json_data)
+    df = pd.read_json(plot_data.dataframe_json)
+
+    x_ticks = []
+    for row in df.itertuples():
+        position = row.position
+        token = row.tokens
+        x_ticks.append(f"{str(position).zfill(3)}|{token}")
+    df["position_with_token"] = x_ticks
+    print(df)
+
+    x_axis = alt.Axis(
+        labelExpr="split(datum.label, '|')[1]",
+        grid=False,
+        labelOverlap=False,
+        labelAngle=0,
+    )
+    width = 1200
+    height = 150
+    base = alt.Chart(df).properties(width=width, height=height)
+    points = base.mark_line(point=True).encode(
+        x=alt.X("position_with_token:O", title=None, axis=x_axis),
+        y=alt.Y(
+            "entropies",
+            title="Entropy of Next Byte",
+        ),
+    )
+    rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
+        y=alt.datum(plot_data.threshold),
+    )
+    patch_rules = (
+        alt.Chart(df[df["start"] > 0])
+        .properties(width=width, height=height)
+        .mark_rule(color="#474747", strokeDash=[4, 2])
+        .encode(x=alt.X("position_with_token:O", axis=x_axis))
+    )
+
+    chart = patch_rules + rule + points
+    chart = chart.configure_axis(labelFontSize=15, titleFontSize=15)
+    path = Path(plot_config.chart_path)
+    path.parent.mkdir(exist_ok=True)
+    chart.save(path)
+
+
+if __name__ == "__main__":
+    main()
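+
+# Editor's note (illustrative, not part of the original commit): the script takes the
+# YAML config as its first argument and accepts OmegaConf dotlist overrides after it,
+# e.g.
+#
+#   python -m bytelatent.plotting.entropy_figure \
+#       bytelatent/plotting/config_entropy_figure.yaml \
+#       chart_path=figures/entropy_figure.pdf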
diff --git a/bytelatent/plotting/scaling_figures.py b/bytelatent/plotting/scaling_figures.py
new file mode 100644
index 0000000..a49624a
--- /dev/null
+++ b/bytelatent/plotting/scaling_figures.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import sys
+from pathlib import Path
+
+import altair as alt
+import pandas as pd
+import pydantic
+from omegaconf import OmegaConf
+
+
+class ScalingPlotsConfig(pydantic.BaseModel):
+    df_dir: str
+    output_chart_dir: str
+    frame_files: list[str]
+
+    class Config:
+        extra = "forbid"
+
+
+def determine_family(key: str):
+    if key.startswith("Megabyte++"):
+        return "Megabyte++"
+    elif key.startswith("BLT"):
+        return "BLT"
+    elif key.startswith("LLaMA"):
+        return "LLaMA"
+    elif key.startswith("Space"):
+        return "Space"
+
+
+file_to_vars = {}
+
+
+def create_chart(df: pd.DataFrame, output_file: str):
+    df["metric"] = df["bpb/not_heldout.jsonl"]
+    df["family"] = df["key"].map(determine_family)
+    model_domain = [
+        "BLT Space ps=6",
+        "BLT Space w/o cross-attn",
+        "SpaceByte",
+        "LLaMA 3 BPE",
+        "Megabyte++ ps=4",
+        "Megabyte++ ps=6",
+    ]
+    color_range = ["#1f77b4", "#1f77b4", "#1f77b4", "#ff7f0e", "#2ca02c", "#2ca02c"]
+    shape_range = [
+        "circle",
+        "square",
+        "cross",
+        "diamond",
+        "triangle-up",
+        "triangle-down",
+    ]
+    color_scale = alt.Scale(domain=model_domain, range=color_range)
+    shape_scale = alt.Scale(
+        domain=model_domain,
+        range=shape_range,
+    )
+    base_chart = alt.Chart(df).encode(
+        x=alt.X("flops", title="Training FLOPS")
+        .scale(type="log", domain=[2e20, 1.25e22])
+        .axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]),
+        y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False),
+    )
+    lines = base_chart.encode(
+        color=alt.Color("key", title="Model Color", scale=color_scale, legend=None),
+        strokeDash=alt.StrokeDash("family", title="Model Family", legend=None),
+    ).mark_line()
+    points = base_chart.encode(
+        color=alt.Color("key", title="Model", scale=color_scale),
+        shape=alt.Shape("key", title="", scale=shape_scale),
+    ).mark_point(size=70)
+    chart = (
+        (lines + points)
+        .resolve_scale(
+            color="independent",
+            shape="independent",
+            # strokeDash="independent",
+        )
+        .configure_legend(orient="right")
+        .properties(height=300, width=400)
+    )
+    print("Saving", output_file)
+    chart.save(output_file)
+
+
+def main():
+    config_path = sys.argv[1]
+    file_config = OmegaConf.load(config_path)
+    # Omit program name and config file name
+    cli_conf = OmegaConf.from_cli(sys.argv[2:])
+    conf_dict = OmegaConf.to_container(
+        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
+    )
+    plot_config = ScalingPlotsConfig(**conf_dict)
+    df_dir = Path(plot_config.df_dir)
+    chart_dir = Path(plot_config.output_chart_dir)
+    chart_dir.mkdir(exist_ok=True, parents=True)
+    for ff in plot_config.frame_files:
+        path = df_dir / ff
+        df = pd.read_json(path)
+        print(df)
+        print(df.columns)
+        create_chart(df, chart_dir / f"{path.name}.pdf")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/bytelatent/preprocess/__init__.py b/bytelatent/preprocess/__init__.py
new file mode 100644
index 0000000..71ca4b1
--- /dev/null
+++ b/bytelatent/preprocess/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
diff --git a/bytelatent/preprocess/data_pipeline.py b/bytelatent/preprocess/data_pipeline.py
new file mode 100644
index 0000000..7dd5283
--- /dev/null
+++ b/bytelatent/preprocess/data_pipeline.py
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import subprocess
+from pathlib import Path
+
+import luigi
+
+# CHANGEME: Change this to point to your data
+BASE_DIR = Path("datasets")
+DATASETS = ["dclm"]
+TARGET_DIR = Path("entropy_preprocess")
+
+SHARD_SCRIPT = """split -C 2500m -d {source} {destination}.shard_"""
+
+
+def list_dataset_shards(dataset: str):
+    dataset_dir = BASE_DIR / dataset
+    return list(dataset_dir.glob("*.chunk.*.jsonl"))
+
+
+class ChunkFile(luigi.ExternalTask):
+    file = luigi.Parameter()
+
+    def output(self):
+        return luigi.LocalTarget(self.file)
+
+
+class ShardDatasetChunk(luigi.Task):
+    dataset_name = luigi.Parameter()
+    chunk_file = luigi.Parameter()
+
+    def _chunk_filename(self):
+        return Path(self.chunk_file).name
+
+    def requires(self):
+        return ChunkFile(self.chunk_file)
+
+    def run(self):
+        destination_dir = TARGET_DIR / str(self.dataset_name)
+        destination_dir.mkdir(parents=True, exist_ok=True)
+        destination = destination_dir / self._chunk_filename()
+        subprocess.check_output(
+            SHARD_SCRIPT.format(source=str(self.chunk_file), destination=destination),
+            shell=True,
+        )
+        (
+            Path(TARGET_DIR)
+            / str(self.dataset_name)
+            / f"{self._chunk_filename()}.shard.COMPLETE"
+        ).touch()
+
+    def output(self):
+        return luigi.LocalTarget(
+            TARGET_DIR
+            / str(self.dataset_name)
+            / f"{self._chunk_filename()}.shard.COMPLETE"
+        )
+
+
+class ShardDataset(luigi.WrapperTask):
+    dataset_name = luigi.Parameter()
+
+    def requires(self):
+        for f in list_dataset_shards(self.dataset_name):
+            yield ShardDatasetChunk(dataset_name=self.dataset_name, chunk_file=str(f))
+
+
+class ShardAllDatasets(luigi.WrapperTask):
+    def requires(self):
+        for d in DATASETS:
+            yield ShardDataset(dataset_name=d)
+
+
+if __name__ == "__main__":
+    luigi.build([ShardAllDatasets()], local_scheduler=True, workers=128)
diff --git a/bytelatent/preprocess/parallel_entropies.py b/bytelatent/preprocess/parallel_entropies.py
new file mode 100644
index 0000000..ac2dbd4
--- /dev/null
+++ b/bytelatent/preprocess/parallel_entropies.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import subprocess
+from pathlib import Path
+
+import submitit
+import typer
+
+
+class PreprocessEntropiesJob(submitit.helpers.Checkpointable):
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, shard_file: str, output_filename: str):
+        subprocess.run(
+            [
+                "python",
+                "-u",
+                "-m",
+                "bytelatent.preprocess.preprocess_entropies",
+                str(shard_file),
+                str(output_filename),
+            ],
+            check=True,
+        )
+        return True
+
+
+def chunk(items, size):
+    for i in range(0, len(items), size):
+        yield items[i : i + size]
+
+
+def main(
+    job_folder: str,
+    input_dir: str,
+    output_dir: str,
+    qos: str = "explore",
+    slurm_batch_size: int = 1000,
+    check_only: bool = False,
+    wait: bool = False,
+):
+    input_dir = Path(input_dir)
+    output_dir = Path(output_dir)
+    shard_files = [
+        p for p in input_dir.glob("*.jsonl.shard*") if "COMPLETE" not in p.name
+    ]
+    if check_only:
+        exist = []
+        missing = []
+        for shard_file in shard_files:
+            shard_file = Path(shard_file)
+            complete_file = output_dir / f"{shard_file.name}.arrow.complete"
+            if complete_file.exists():
+                exist.append(complete_file)
+            else:
+                missing.append(complete_file)
+        print("Checked for output files for input_dir=", input_dir)
+        print("Exist:", len(exist))
+        print("Missing:", len(missing))
+        print(missing)
+        return
+    print("Running parallel job over N files=", len(shard_files))
+    print("Input Directory:", input_dir)
+    print("Output Directory:", output_dir)
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    executor = submitit.SlurmExecutor(job_folder)
+    executor.update_parameters(
+        # 12 hours in minutes
+        time=60 * 12,
+        qos=qos,
+        exclusive="user",
+        cpus_per_task=4,
+        num_gpus=1,
+        mem_per_gpu="80G",
+        array_parallelism=slurm_batch_size,
+    )
+
+    jobs = []
+    n_batches = 0
+    n_skipped = 0
+    n_launched = 0
+    for file_batch in chunk(shard_files, slurm_batch_size):
+        with executor.batch():
+            for shard_file in file_batch:
+                output_filename = Path(output_dir) / f"{shard_file.name}.arrow"
+                complete_output_filename = (
+                    Path(output_dir) / f"{shard_file.name}.arrow.complete"
+                )
+                if complete_output_filename.exists():
+                    n_skipped += 1
+                else:
+                    job = executor.submit(
+                        PreprocessEntropiesJob(), str(shard_file), str(output_filename)
+                    )
+                    n_launched += 1
+                    jobs.append(job)
+        n_batches += 1
+    print("launched array jobs n=", n_launched)
+    print("skipped (completed) array jobs n=", n_skipped)
+    print("number of slurm batches=", n_batches)
+    if wait:
+        output = [job.result() for job in jobs]
+        assert all(output)
+
+
+if __name__ == "__main__":
+    typer.run(main)
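+
+# Editor's note (illustrative, not part of the original commit; paths are placeholders):
+# with typer, the positional arguments are job_folder, input_dir and output_dir, and the
+# keyword parameters above become flags, e.g.
+#
+#   python -m bytelatent.preprocess.parallel_entropies \
+#       /path/to/job_folder entropy_preprocess/dclm /path/to/output_arrow --check-only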
diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py
new file mode 100644
index 0000000..20d1e0c
--- /dev/null
+++ b/bytelatent/preprocess/preprocess_entropies.py
@@ -0,0 +1,141 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import time
+from pathlib import Path
+
+import numpy as np
+import pyarrow as pa
+import torch
+import typer
+from rich.progress import Progress, TextColumn
+
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+
+
+def main(
+    input_file: str,
+    output_file: str,
+    patching_device: str = "cuda",
+    log_step: int = 10_000,
+    entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
+    dry_run: bool = False,
+):
+    # TODO: Modify this to work with the new code
+    raise NotImplementedError()
+    iterator = ArrowFileIterator(
+        file_path=input_file,
+        worker_id=0,
+        num_workers=1,
+    )
+    tokenization_mode = "bytes"
+    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
+    print("Loading entropy model", entropy_model_checkpoint_dir)
+    if dry_run:
+        return
+    entropy_model = load_entropy_model(
+        entropy_model_checkpoint_dir, device=patching_device
+    )
+    entropy_model, _ = to_device(entropy_model, patching_device)
+    print("Creating patcher")
+    patching_batch_size = 32
+    print("Creating tokenizer")
+    tokenizer = Tokenizer(
+        model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
+        tokenization_mode=tokenization_mode,
+        # BYTE_UNITS
+        vocab_size_unit_1=256,
+        bos=True,
+        eos=True,
+        bpe_delim=False,
+        # This isn't used, just stores a reference for other calls we don't use
+        patcher=None,
+    )
+    step = 0
+    print("starting")
+    start_time = time.time()
+    patch_time = 0
+    entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
+    sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
+    text_field = pa.field("text", pa.string(), nullable=False)
+    schema = pa.schema([sample_id_field, text_field, entropy_field])
+    arrow_batch_size = 1_000
+
+    try:
+        with pa.OSFile(output_file, "wb") as sink:
+            with pa.ipc.new_file(sink, schema) as writer:
+                id_buffer = []
+                entropies_buffer = []
+                text_buffer = []
+                with Progress(
+                    *Progress.get_default_columns(),
+                    TextColumn("Completed: {task.completed}"),
+                ) as progress:
+                    task = progress.add_task(
+                        "[green]Calculating entropies...", total=None
+                    )
+                    for doc in iterator:
+                        sample_id = get_id_from_doc(doc)
+
+                        if "text" in doc:
+                            text = doc["text"]
+                        elif "content" in doc:
+                            text = doc["content"]
+                        else:
+                            raise ValueError(
+                                f"Could not find a text key from: {doc.keys()}"
+                            )
+                        tokens = torch.tensor(tokenizer.encode(text))
+                        patch_start = time.time()
+                        scores = calculate_entropies(
+                            tokens,
+                            entropy_model,
+                            patching_batch_size,
+                            patching_device,
+                        )
+                        entropies_buffer.append(
+                            np.array(scores.tolist(), dtype=np.float16)
+                        )
+                        id_buffer.append(sample_id)
+                        text_buffer.append(text)
+                        if len(entropies_buffer) == arrow_batch_size:
+                            batch = pa.record_batch(
+                                {
+                                    "entropies": entropies_buffer,
+                                    "sample_id": id_buffer,
+                                    "text": text_buffer,
+                                },
+                                schema,
+                            )
+                            writer.write(batch)
+                            entropies_buffer = []
+                            id_buffer = []
+                            text_buffer = []
+                        patch_time += time.time() - patch_start
+                        step += 1
+                        if step % log_step == 0:
+                            print("Completed steps:", step)
+                        progress.update(task, advance=1)
+                    if len(entropies_buffer) > 0:
+                        # Write any remaining buffered rows
+                        batch = pa.record_batch(
+                            {
+                                "entropies": entropies_buffer,
+                                "sample_id": id_buffer,
+                                "text": text_buffer,
+                            },
+                            schema,
+                        )
+                        writer.write(batch)
+                        entropies_buffer = []
+                        id_buffer = []
+                        text_buffer = []
+        Path(f"{output_file}.complete").touch()
+    except:
+        Path(output_file).unlink(missing_ok=True)
+        raise
+    elapsed = time.time() - start_time
+    print("steps", step)
+    print("done in:", elapsed)
+
+
+if __name__ == "__main__":
+    typer.run(main)
diff --git a/bytelatent/probe.py b/bytelatent/probe.py
new file mode 100644
index 0000000..7bdc03c
--- /dev/null
+++ b/bytelatent/probe.py
@@ -0,0 +1,694 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# This file from the xFormers repo is just an example of how to implement
+# probing of the activations of a model, without changing anything.
+# By default, the linear inputs/outputs/gradients are logged, as well as
+# the attention logits+entropy. It is possible to log an additional tensor, eg:
+# x = log_stats(x, "name")
+#
+# Known limitations:
+# * Only a subset of the attention biases is supported
+# * Torch-compile is disabled automatically when this is enabled
+# * Only tested with bf16/f16/f32 datatypes
+
+import contextlib
+import functools
+import json
+import math
+import os
+import uuid
+from collections import defaultdict
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    CheckpointImpl,
+    checkpoint_wrapper,
+)
+from torch.fx.operator_schemas import normalize_function
+from torch.nn.attention import SDPBackend, sdpa_kernel
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._pytree import tree_map
+from torch.utils.module_tracker import ModuleTracker
+from xformers.ops import fmha
+
+
+@torch.library.custom_op("torchprobe::log", mutates_args=(), device_types=None)
+def _log(x: torch.Tensor, name: str, uid: str) -> None:
+    pass
+
+
+@_log.register_fake
+def _log_fake(x: torch.Tensor, name: str, uid: str) -> None:
+    pass
+
+
+class _LogStats(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, name: str):
+        uid = str(uuid.uuid4())
+        torch.ops.torchprobe.log(x, name, uid)
+        ctx.name = name
+        ctx.uid = uid
+        return x
+
+    @staticmethod
+    def backward(ctx, grad: torch.Tensor):
+        torch.ops.torchprobe.log(grad, f"{ctx.name}.g", ctx.uid)
+        return grad, None
+
+
+_PROBING_ENABLED = False
+
+
+def log_stats(x: torch.Tensor, name: str) -> torch.Tensor:
+    if not _PROBING_ENABLED:
+        return x
+    return _LogStats.apply(x, name)
+
+
+QUANTILES = [
+    0.0000001,
+    0.000001,
+    0.00001,
+    0.0001,
+    0.001,
+    0.01,
+    0.05,
+    0.1,
+    0.3,
+    0.5,
+    0.7,
+    0.9,
+    0.95,
+    0.99,
+    0.999,
+    0.9999,
+    0.99999,
+    0.999999,
+    0.9999999,
+]
+
+
+@functools.cache
+def _get_quantiles(device: torch.device, dtype) -> torch.Tensor:
+    return torch.tensor(QUANTILES, device=device, dtype=dtype)
+
+
+def _get_stats(x_: torch.Tensor, remove_inf=False) -> Dict[str, Any]:
+    if x_.dtype not in [torch.float, torch.double, torch.float16, torch.bfloat16]:
+        return {}
+    x = x_.flatten()
+    if remove_inf:
+        x = x[x.abs() < float("inf")]
+    if x.dtype is not torch.double:
+        x = x.float()
+    xabs = x.abs()
+    quantiles = _get_quantiles(x.device, x.dtype)
+    mean = x.mean()
+    std = x.std()
+    return {
+        "shape": tuple(x_.shape),
+        "mean": mean,
+        "std": std,
+        "skew": (((x - mean) / std) ** 3).double().mean(),
+        "kurtosis": (((x - mean) / std) ** 4).double().mean(),
+        "abs.mean": xabs.mean(),
+        "max": x.max(),
+        "min": x.min(),
+        # Note: `quantile` takes at most 2**24 elements, see
+        # https://github.com/pytorch/pytorch/issues/64947
+        "quantiles": torch.quantile(x[: 2**24], quantiles),
+    }
+
+
+def _mask_attn_causal_inplace(logits: torch.Tensor, q_idx, q_len, kv_len) -> None:
+    assert logits.ndim == 4
+    logits[:, :, :, q_idx + kv_len - q_len + 1 :] = -math.inf
+
+
+def _mask_attn_logits(
+    logits: torch.Tensor,
+    q_idx: List[int],
+    *,
+    causal: bool,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert logits.dtype is torch.float32
+    # Handle BlockDiagonalMask
+    if cu_seqlens_q is not None:
+        assert cu_seqlens_k is not None
+        # Expect BHMqMkv
+        assert logits.ndim == 4, logits.shape
+        qs = cu_seqlens_q.tolist()
+        ks = cu_seqlens_k.tolist()
+        q_batchid = []
+        k_batchid = [-2] * logits.shape[-1]
+        q_idx_i = 0
+        for bid, (q0, q1, k0, k1) in enumerate(zip(qs, qs[1:], ks, ks[1:])):
+            for k in range(k0, k1):
+                k_batchid[k] = bid
+            while q_idx_i < len(q_idx) and q_idx[q_idx_i] < q1:
+                q_batchid.append(bid)
+                if causal:
+                    _mask_attn_causal_inplace(
+                        logits[:, :, q_idx_i : q_idx_i + 1, k0:k1],
+                        q_idx[q_idx_i] - q0,
+                        q1 - q0,
+                        k1 - k0,
+                    )
+                q_idx_i += 1
+        mask_out = (
+            torch.tensor(q_batchid, device=logits.device)[None, None, :, None]
+            != torch.tensor(k_batchid, device=logits.device)[None, None, None, :]
+        )
+        logits[mask_out.expand_as(logits)] = -math.inf
+        assert q_idx_i == len(q_idx)
+    elif causal:
+        for q_idx_i in range(len(q_idx)):
+            _mask_attn_causal_inplace(
+                logits[:, :, q_idx_i : q_idx_i + 1, :],
+                q_idx[q_idx_i],
+                logits.shape[2],
+                logits.shape[3],
+            )
+    return logits
+
+
+def _attn_queries_subset(num_queries: int) -> List[int]:
+    return list(range(0, num_queries, max(1, num_queries // 128)))
+
+
+@torch.no_grad()
+def _compute_attn_stats_sdpa(
+    probe,
+    path: str,
+    # supports arguments both cudnn + flash backends
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask=None,
+    attn_bias=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    compute_log_sumexp=True,
+    return_debug_mask=False,
+    **kwargs,
+):
+    if scale is None:
+        scale = 1 / (query.shape[-1] ** 0.5)
+    # Filter-out not supported cases
+    if attn_mask is not None or attn_bias is not None or dropout_p != 0.0 or kwargs:
+        probe.store[f"{path}::attn"] = {
+            "query.shape": tuple(query.shape),
+            "key.shape": tuple(key.shape),
+            "value.shape": tuple(value.shape),
+            "attn_mask": attn_mask.shape if attn_mask is not None else None,
+            "dropout_p": dropout_p,
+            "is_causal": is_causal,
+            "scale": scale,
+            "unk_kwargs": list(kwargs.keys()),
+        }
+        return
+    # Take a subset of the queries and compute the logits
+    query_s = _attn_queries_subset(query.shape[-2])
+    logits = query[:, :, query_s] @ key.transpose(-1, -2) * scale
+    logits = _mask_attn_logits(logits.float(), query_s, causal=is_causal)
+    p = logits.float().softmax(-1)
+    masked_logsoft = logits.log_softmax(-1).where(
+        (logits > -math.inf), torch.zeros_like(logits)
+    )
+    entropy = -(p * masked_logsoft).sum(-1)
+    probe.log_tensor(f"{path}::attn_entropy", entropy)
+    probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
+
+
+@torch.no_grad()
+def _compute_attn_stats_flash(
+    probe,
+    path: str,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor],
+    cu_seqlens_k: Optional[torch.Tensor],
+    seqused_k: Optional[torch.Tensor],
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    p: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_left: int,
+    window_right: int,
+    return_softmax: bool,
+    block_tables: Optional[torch.Tensor],
+    unpadded_lse: bool = False,
+) -> None:
+    # Filter-out not supported cases
+    if (
+        seqused_k is not None
+        or p != 0.0
+        or window_left >= 0
+        or window_right >= 0
+        or block_tables is not None
+    ):
+        probe.store[f"{path}::attn"] = {
+            "query.shape": tuple(query.shape),
+            "key.shape": tuple(key.shape),
+            "value.shape": tuple(value.shape),
+            "op": "flash",
+        }
+        return
+
+    if cu_seqlens_q is not None:
+        assert query.ndim == 3, query.shape
+        query, key, value = query[None], key[None], value[None]
+    assert query.ndim == 4, query.shape
+
+    # Take a subset of the queries and compute the logits
+    query_s = _attn_queries_subset(query.shape[1])
+    logits = (
+        query[:, query_s].transpose(1, 2)
+        @ key.transpose(1, 2).transpose(-1, -2)
+        * softmax_scale
+    )
+    logits = _mask_attn_logits(
+        logits.float(),
+        query_s,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        causal=is_causal,
+    )
+    p = logits.float().softmax(-1)
+    masked_logsoft = logits.log_softmax(-1).where(
+        (logits > -math.inf), torch.zeros_like(logits)
+    )
+    entropy = -(p * masked_logsoft).sum(-1)
+    probe.log_tensor(f"{path}::attn_entropy", entropy)
+    probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
+
+
+def _tensors_to_python(x):
+    if not isinstance(x, torch.Tensor):
+        return x
+    return x.tolist()
+
+
+# class syntax
+class LinearBwType(Enum):
+    DW = 1
+    DX = 2
+    UNKNOWN = 3
+
+
+class AutoProbeD(TorchDispatchMode):
+    def __init__(self, module: nn.Module, write_file: Optional[str] = None) -> None:
+        self.write_file = Path(write_file) if write_file is not None else None
+        self.write_tensors_tmpdir: Optional[Path] = None
+        self.compile_disabler = TorchCompileDisabler(module)
+        self.mod_tracker = ModuleTracker()
+        self.count_per_path: Dict[str, int] = defaultdict(int)
+        self.store: Dict[str, Dict[str, Any]] = {}
+        self.linear_data: Dict[str, Tuple[Any, Any, Any, Any, Any]] = {}
+        self.uid_to_path: Dict[str, str] = {}
+        self.metadata: Any = None
+        self.enabled = False
+        self.verbose = bool(int(os.environ.get("PROBE_VERBOSE", "0")))
+
+    def __enter__(self):
+        global _PROBING_ENABLED
+        assert not self.enabled, "Entered probe twice"
+        self.compile_disabler.__enter__()
+        self.mod_tracker.__enter__()
+        super().__enter__()
+        self.enabled = True
+        _PROBING_ENABLED = True
+        # self._setup_tensors_logging()
+        return self
+
+    def __exit__(self, *args) -> None:
+        global _PROBING_ENABLED
+        assert self.enabled, "Exiting probe without entering it"
+        super().__exit__(*args)
+        self.mod_tracker.__exit__(*args)
+        self.compile_disabler.__exit__(*args)
+        self._flush_and_clear()
+        _PROBING_ENABLED = False
+        self.enabled = False
+
+    def _setup_tensors_logging(self):
+        if self.write_file is not None:
+            self.write_file.parent.mkdir(exist_ok=True)
+            self.write_tensors_tmpdir = (
+                self.write_file.parent
+                / f"{self.write_file.name}-tmp-{str(uuid.uuid4())[:8]}"
+            )
+            self.write_tensors_tmpdir.mkdir(exist_ok=True)
+
+    def _flush_and_clear(self) -> None:
+        if self.write_file is not None:
+            dump_data = tree_map(_tensors_to_python, self.store)
+            with self.write_file.open("a") as fd:
+                json.dump(
+                    {
+                        "data": dump_data,
+                        "meta": self.metadata,
+                        "version": 2,
+                        "quantiles": QUANTILES,
+                    },
+                    fd,
+                )
+                fd.write("\n")
+        if self.write_tensors_tmpdir is not None:
+            assert self.write_file is not None
+            dump_dir = self.write_tensors_tmpdir.parent / f"{self.write_file.name}-dump"
+            dump_dir.mkdir(exist_ok=True)
+            dir_name = ""
+            if "it" in self.metadata:
+                dir_name = f"it{int(self.metadata['it']):010}"
+            if dir_name == "" or (dump_dir / dir_name).exists():
+                num_files = len(list(dump_dir.glob(f"{dir_name}v*")))
+                dir_name = f"{dir_name}v{num_files}"
+            dump_dir = dump_dir / dir_name
+            assert not dump_dir.exists()
+            self.write_tensors_tmpdir.rename(dump_dir)
+            self.write_tensors_tmpdir = None
+        self.store.clear()
+        self.count_per_path.clear()
+        self.uid_to_path.clear()
+
+    def _find_bw_path_and_type(
+        self, path: str, out: torch.Tensor, args
+    ) -> Tuple[str, LinearBwType]:
+        """
+        We are in the BW pass, and process a GEMM.
+        Let's figure out:
+        (1) The path for the FW pass (might differ in case of ModuleTracker bug)
+        (2) The type of BW pass (eg `dw` or `dx`)
+        """
+
+        def _is_path_correct_dw(path: str) -> bool:
+            # dW.t = dY.t @ X
+            in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
+            return out.shape == (w_shape[1], w_shape[0]) and torch.allclose(
+                input_sm, args[1][:4, :4]
+            )
+
+        def _is_path_correct_dx(path: str) -> bool:
+            # dX = dY @ W.t
+            in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
+            return out.shape == in_shape and torch.allclose(weight_sm, args[1][:4, :4])
+
+        if path in self.linear_data:
+            if _is_path_correct_dw(path):
+                return path, LinearBwType.DW
+            if _is_path_correct_dx(path):
+                return path, LinearBwType.DX
+        for candidate_path in self.mod_tracker.parents:
+            if candidate_path not in self.linear_data:
+                continue
+            if _is_path_correct_dw(candidate_path):
+                return candidate_path, LinearBwType.DW
+            if _is_path_correct_dx(candidate_path):
+                return candidate_path, LinearBwType.DX
+        return path, LinearBwType.UNKNOWN
+
+    def log_tensor(self, name: str, x: torch.Tensor, **kwargs) -> None:
+        self.store[name] = _get_stats(x, **kwargs)
+        if self.write_tensors_tmpdir is not None:
+            name_safe = name.replace("::", "__").replace("/", "")
+            torch.save(x, self.write_tensors_tmpdir / f"{name_safe}.pkl")
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs if kwargs else {}
+        path = None
+        # Find longest path
+        for p in self.mod_tracker.parents:
+            if p == "Global":
+                continue
+            if path is None or len(p) > len(path):
+                path = p
+        if path is None:
+            path = "Global"
+        path = path.replace("._checkpoint_wrapped_module", "")
+        out = func(*args, **kwargs)
+
+        # Handle linear layers
+        if func._overloadpacket in [torch.ops.aten.addmm, torch.ops.aten.mm]:
+            weight: torch.Tensor
+            input: torch.Tensor
+            if not self.mod_tracker.is_bw:
+                # (technically, weight is transposed)
+                if func._overloadpacket == torch.ops.aten.addmm:
+                    _bias, input, weight = args[:3]
+                else:
+                    assert func._overloadpacket == torch.ops.aten.mm
+                    input, weight = args[:2]
+                self.log_tensor(f"{path}::in", input)
+                self.log_tensor(f"{path}::w", weight)
+                self.log_tensor(f"{path}::out", out)
+                self.linear_data[path] = (
+                    input.shape,
+                    weight.shape,
+                    out.shape,
+                    input[:4, :4].clone(),
+                    weight[:4, :4].T.clone(),
+                )
+            elif func._overloadpacket == torch.ops.aten.mm:
+                # XXX: Try to find the actual path for the linear layer
+                # This sometimes gets messed up with Francisco's FSDP
+                new_path, bwtype = self._find_bw_path_and_type(path, out, args)
+                if new_path != path:
+                    if self.verbose:
+                        print(f"E: Fixing path `{path}` -> `{new_path}")
+                    path = new_path
+
+                if bwtype == LinearBwType.DW:
+                    # dW.t = dY.t @ X
+                    self.log_tensor(f"{path}::w.g", out)
+                elif bwtype == LinearBwType.DX:
+                    # dX = dY @ W.t
+                    self.log_tensor(f"{path}::in.g", out)
+                    self.log_tensor(f"{path}::out.g", args[0])
+        elif func._overloadpacket in [
+            torch.ops.aten._scaled_dot_product_flash_attention,
+            torch.ops.aten._scaled_dot_product_cudnn_attention,
+        ]:
+            _, kwargs = normalize_function(
+                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+            )
+            _compute_attn_stats_sdpa(self, path, **kwargs)
+        elif func._overloadpacket == fmha.flash.FwOp.OPERATOR:
+            _, kwargs = normalize_function(
+                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+            )
+            _compute_attn_stats_flash(self, path, **kwargs)
+        elif func._overloadpacket == torch.ops.torchprobe.log:
+            uid = args[2]
+            path = self.uid_to_path.setdefault(uid, path)
+            self.log_tensor(f"{path}::{args[1]}", args[0])
+        if self.verbose:
+            print(f"{'[BW]' if self.mod_tracker.is_bw else '[FW]'} `{path}`: {func}")
+        return out
+
+
+def _find_all_submodules_compiled(out: List[nn.Module], module: nn.Module) -> None:
+    if module._compiled_call_impl is not None:
+        out.append(module)
+    for c in module.children():
+        _find_all_submodules_compiled(out, module=c)
+
+
+class TorchCompileDisabler:
+    def __init__(self, module: nn.Module) -> None:
+        self.module = module
+        self.submodules_compiled: List[nn.Module] = []
+        self.compiled_call_impl: List[Any] = []
+        self.disable_compile = torch.compiler.disable()
+        torch._dynamo.config.raise_on_ctx_manager_usage = False  # type: ignore
+
+    def __enter__(self) -> None:
+        # Remove all `_compiled_call_impl` attributes to effectively
+        # "undo" compilation
+        self.submodules_compiled.clear()
+        _find_all_submodules_compiled(self.submodules_compiled, self.module)
+        self.compiled_call_impl = [
+            m._compiled_call_impl for m in self.submodules_compiled
+        ]
+        for m in self.submodules_compiled:
+            m._compiled_call_impl = None
+        self.disable_compile.__enter__()  # type: ignore
+
+    def __exit__(self, *args) -> None:
+        self.disable_compile.__exit__(*args)  # type: ignore
+        for m, c_impl in zip(self.submodules_compiled, self.compiled_call_impl):
+            m._compiled_call_impl = c_impl
+        self.compiled_call_impl = []
+
+
+Probe = AutoProbeD
+
+# EXAMPLE USAGE
+d = 512
+seqlen = 4
+bs = 2
+
+
+class Attention1(nn.Module):
+    def forward(self, x):
+        attn_bias = fmha.attn_bias.LowerTriangularFromBottomRightMask()
+        return fmha.memory_efficient_attention(x, x, x, attn_bias=attn_bias).reshape(
+            [x.shape[0], seqlen, -1]
+        )
+
+
+class Attention2(nn.Module):
+    def forward(self, x):
+        attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
+            [seqlen] * bs
+        ).make_causal()
+        xr = x.reshape([1, 2 * seqlen, x.shape[2], x.shape[3]])
+        return fmha.memory_efficient_attention(xr, xr, xr, attn_bias=attn_bias).reshape(
+            [x.shape[0], seqlen, -1]
+        )
+
+
+class AttentionSDPA(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.wo = nn.Linear(d, d)
+
+    def forward(self, x):
+        x = x.transpose(1, 2)
+        return self.wo(
+            F.scaled_dot_product_attention(x, x, x)
+            .transpose(1, 2)
+            .reshape([x.shape[0], seqlen, -1])
+        )
+
+
+class AttentionSDPAFlash(AttentionSDPA):
+    def forward(self, x):
+        x = x.transpose(1, 2)
+        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+            return self.wo(
+                F.scaled_dot_product_attention(x, x, x)
+                .transpose(1, 2)
+                .reshape([x.shape[0], seqlen, -1])
+            )
+
+
+class Model(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.head = nn.Linear(d, 16)
+        self.trunk = nn.Sequential(
+            nn.Linear(d, d),
+            nn.Linear(d, d),
+        )
+        self.q_proj = nn.Linear(d, d, bias=False)
+        self.trunk.compile()
+        self.attn1 = Attention1()
+        self.attn2 = Attention2()
+        self.attnSDPA = AttentionSDPA()
+        self.attnSDPAflash = AttentionSDPAFlash()
+
+    def forward(self, x):
+        B, nHeads, D = x.shape[0], d // 64, 64
+        x = self.q_proj(x).reshape([B, seqlen, nHeads, D])
+        x = self.attn1(x) + self.attn2(x) + self.attnSDPA(x) + self.attnSDPAflash(x)
+        x = log_stats(x, "attns_out")
+        return self.head(self.trunk(x))
+
+
+def test_masking() -> None:
+    q_seqlen = [1, 1, 14, 12]
+    kv_seqlen = [2, 2, 14, 18]
+    attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+        q_seqlen, kv_seqlen
+    ).make_causal_from_bottomright()
+    logits = torch.randn(
+        [1, 1, sum(q_seqlen), sum(kv_seqlen)], dtype=torch.float32, device="cuda"
+    )
+    bias = attn_bias.materialize(logits.shape, dtype=logits.dtype, device=logits.device)
+    logits_masked = logits.clone()
+    _mask_attn_logits(
+        logits_masked,
+        list(range(logits.shape[2])),
+        causal=True,
+        cu_seqlens_q=attn_bias.q_seqinfo.seqstart,
+        cu_seqlens_k=attn_bias.k_seqinfo.seqstart,
+    )
+    assert (logits + bias == logits_masked).all().item()
+
+
+def test_toy_model() -> None:
+    kw = dict(device="cuda", dtype=torch.float16)
+    x = torch.randn([bs, seqlen, d], **kw)
+    m = Model()
+    m.head = checkpoint_wrapper(
+        m.head, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
+    )
+    m.to(**kw)
+    m.compile()
+    optim = torch.optim.SGD(m.parameters(), lr=0.0)
+    probe = AutoProbeD(m, "./probe.json")
+
+    for i in range(4):
+        with contextlib.ExitStack() as stack:
+            print(f"########### STEP {i}")
+            if i % 4 == 1:
+                stack.enter_context(probe)
+                probe.metadata = {"it": i}
+            y = m(x)
+            g = torch.randn_like(y)
+            y.backward(g)
+            if i % 4 == 1:
+                assert probe.enabled
+                # Make sure we registered all linears
+                print(list(probe.store.keys()))
+                for key in [
+                    "Model::attns_out",
+                    "Model::attns_out.g",
+                    "Model.attn1::attn_logits",
+                    "Model.attn2::attn_logits",
+                    "Model.attnSDPA::attn_logits",
+                    "Model.attnSDPAflash::attn_logits",
+                    "Model.head::w",
+                    "Model.head::w.g",
+                    "Model.head::in",
+                    "Model.head::in.g",
+                    "Model.head::out",
+                    "Model.head::out.g",
+                    "Model.trunk.0::in",
+                    "Model.trunk.1::in",
+                ]:
+                    assert key in probe.store, f"Missing key: '{key}'"
+                # ... and that the values are correct
+                for key, tensor in [
+                    ("Model.head::w", m.head.weight),
+                    ("Model.head::w.g", m.head.weight.grad),
+                    ("Model.q_proj::in", x),
+                    ("Model.q_proj::w.g", m.q_proj.weight.grad),
+                    ("Model.head::out", y),
+                    ("Model.head::out.g", g),
+                ]:
+                    assert key in probe.store, f"Missing key: '{key}'"
+                    assert torch.allclose(
+                        probe.store[key]["abs.mean"], tensor.float().abs().mean()
+                    ), f"'{key}' mismatches"
+                # Check we don't have `nans`
+                for key, value in probe.store.items():
+                    if "abs.mean" in value:
+                        assert math.isfinite(
+                            value["abs.mean"].item()
+                        ), f"Inf/Nan for {key}"
+            optim.step()
+            optim.zero_grad()
diff --git a/bytelatent/profiling.py b/bytelatent/profiling.py
new file mode 100644
index 0000000..da3c90d
--- /dev/null
+++ b/bytelatent/profiling.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import contextlib
+import logging
+import os
+from pathlib import Path
+
+import torch.distributed
+import wandb
+import xformers.profiler
+from pydantic import BaseModel
+from torch.profiler.profiler import profile
+from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler
+
+from bytelatent.distributed import get_is_master
+
+
+class ProfilerArgs(BaseModel):
+    run: bool = False
+    trace_folder: str = "profiling"
+    mem_warmup: int = 100
+    mem_steps: int = 2
+    profile_warmup: int = 102
+    profile_steps: int = 2
+
+
+logger = logging.getLogger()
+
+
+def perfetto_to_html(json_file, html_file):
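+    """Embed a (possibly gzipped) trace JSON into a standalone HTML page using the
+    trace-viewer template bundled with viztracer."""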
+    import gzip
+    import string
+
+    import viztracer
+
+    root = os.path.dirname(viztracer.__file__)
+    sub = {}
+    json_file = gzip.open(json_file) if ".gz" in str(json_file) else open(json_file)
+    with open(
+        os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8"
+    ) as f:
+        tmpl = f.read()
+    with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f:
+        sub["trace_viewer_full"] = f.read()
+    with json_file as j:
+        content = j.read()
+        if isinstance(content, bytes):
+            content = content.decode("utf-8")
+        sub["json_data"] = content.replace("</script>", "<\\/script>")  # type: ignore
+    with open(html_file, "w+", encoding="utf-8") as output_file:
+        output_file.write(string.Template(tmpl).substitute(sub))
+
+
+class PyTorchProfilerWandb(PyTorchProfiler):
+    def __init__(self, main_profiler) -> None:
+        self.main_profiler = main_profiler
+        self.num_steps = 0
+        self.pytorch_profiler = torch.profiler.profile(
+            on_trace_ready=self._on_trace,
+            profile_memory=True,
+            record_shapes=True,
+            # with_stack=True produces huge profile traces and can crash
+            # because of a non-ASCII character somewhere in PyTorch
+            with_stack=False,
+            with_flops=True,
+            activities=self.ACTIVITIES,
+        )
+
+    def _analyze_trace(self, prof: profile):
+        logger.info("Begin analyze trace")
+        super()._analyze_trace(prof)
+        logger.info("End analyze trace")
+
+    def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
+        super()._on_trace(prof)
+        if get_is_master() and wandb.run is not None:
+            filename = list(
+                Path(self.main_profiler.output_dir).glob(
+                    "profile_CPU_CUDA*/*.pt.trace.json*"
+                )
+            )[0]
+            html_path = str(filename).replace(".json", ".html")
+            perfetto_to_html(filename, html_path)
+            wandb.log({"profile_trace": wandb.Html(html_path)})
+
+
+class MemSnapshotsProfilerWandb(MemSnapshotsProfiler):
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        super().__exit__(exc_type, exc_val, exc_tb)
+        if get_is_master() and wandb.run is not None:
+            filename = list(
+                Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html")
+            )[0]
+            wandb.log({"memory_trace": wandb.Html(open(filename), inject=False)})
+
+
+@contextlib.contextmanager
+def maybe_run_profiler(dump_dir, module, config: ProfilerArgs):
+    # Get user-defined profiler settings
+
+    if config.run:
+        trace_dir = os.path.join(dump_dir, config.trace_folder)
+
+        logger.info(f"Profiling active.  Traces will be saved at {trace_dir}")
+
+        if get_is_master() and not os.path.exists(trace_dir):
+            os.makedirs(trace_dir)
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
+
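+        # Each schedule entry is (profiler class, start step, end step): the memory
+        # snapshot runs for `mem_steps` steps after `mem_warmup`, and the PyTorch
+        # profiler for `profile_steps` steps after `profile_warmup`.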
+        with xformers.profiler.profile(
+            output_dir=trace_dir,
+            module=module,
+            schedule=[
+                (
+                    MemSnapshotsProfilerWandb,
+                    config.mem_warmup,
+                    config.mem_warmup + config.mem_steps,
+                ),
+                (
+                    PyTorchProfilerWandb,
+                    config.profile_warmup,
+                    config.profile_warmup + config.profile_steps,
+                ),
+            ],
+        ) as profiler:
+            yield profiler
+
+    else:
+        yield None
diff --git a/bytelatent/stool.py b/bytelatent/stool.py
new file mode 100644
index 0000000..965f4cb
--- /dev/null
+++ b/bytelatent/stool.py
@@ -0,0 +1,237 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import json
+import os
+import shutil
+import subprocess
+from dataclasses import dataclass
+from typing import Any, Dict
+
+from omegaconf import OmegaConf
+
+
+@dataclass
+class StoolArgs:
+    config: Any = None
+    launcher: str = "sbatch"  # Can be sbatch or bash if already in salloc
+    script: str = "apps.main.train"  # The script to run.
+    copy_code: bool = True  # Whether to copy code to the dump dir
+    dirs_exists_ok: bool = (
+        False  # Whether to copy new code and config and run even if the dir already exists
+    )
+    override: bool = False  # Whether to delete the dump dir and restart
+    nodes: int = -1  # The number of nodes to run the job on.
+    ngpu: int = 8  # The number of GPUs required per node.
+    ncpu: int = 16  # The number of CPUs allocated per GPU.
+    mem: str = ""  # The amount of memory to allocate.
+    anaconda: str = "default"  # The path to the anaconda environment.
+    constraint: str = ""  # The constraint on the nodes.
+    exclude: str = ""  # The nodes to exclude.
+    time: int = -1  # The time limit of the job (in minutes).
+    account: str = ""
+    qos: str = ""
+    partition: str = "learn"
+    stdout: bool = False
+
+
+SBATCH_COMMAND = """#!/bin/bash
+
+{exclude}
+{qos}
+{account}
+{constraint}
+#SBATCH --job-name={name}
+#SBATCH --nodes={nodes}
+#SBATCH --gres=gpu:{ngpus}
+#SBATCH --cpus-per-gpu={ncpu}
+#SBATCH --time={time}
+#SBATCH --partition={partition}
+#SBATCH --mem={mem}
+
+#SBATCH --output={dump_dir}/logs/%j/%j.stdout
+#SBATCH --error={dump_dir}/logs/%j/%j.stderr
+
+#SBATCH --open-mode=append
+#SBATCH --signal=USR2@120
+#SBATCH --distribution=block
+
+# Mimic the effect of "conda init", which doesn't work for scripts
+eval "$({conda_exe} shell.bash hook)"
+source activate {conda_env_path}
+
+{go_to_code_dir}
+
+export OMP_NUM_THREADS=1
+export LAUNCH_WITH="SBATCH"
+export DUMP_DIR={dump_dir}
+srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
+"""
+
+
+def copy_dir(input_dir: str, output_dir: str) -> None:
+    print(f"Copying : {input_dir}\n" f"to      : {output_dir} ...")
+    assert os.path.isdir(input_dir), f"{input_dir} is not a directory"
+    assert os.path.isdir(output_dir), f"{output_dir} is not a directory"
+    rsync_cmd = (
+        f"rsync -arm --copy-links "
+        f"--include '**/' "
+        f"--include '*.py' "
+        f"--exclude='*' "
+        f"{input_dir}/ {output_dir}"
+    )
+    print(f"Copying command: {rsync_cmd}")
+    subprocess.call([rsync_cmd], shell=True)
+    print("Copy done.")
+
+
+def retrieve_max_time_per_partition() -> Dict[str, int]:
+    # retrieve partition max times (a bit slow)
+
+    sinfo = json.loads(subprocess.check_output("sinfo --json", shell=True))["sinfo"]
+    max_times: Dict[str, int] = {}
+
+    for info in sinfo:
+        if info["partition"]["maximums"]["time"]["infinite"]:
+            max_times[info["partition"]["name"]] = 14 * 24 * 60  # 14 days
+        else:
+            max_times[info["partition"]["name"]] = info["partition"]["maximums"][
+                "time"
+            ][
+                "number"
+            ]  # in minutes
+
+    return max_times
+
+
+def validate_args(args) -> None:
+    # Set maximum time limit if not specified
+    if args.time == -1:
+        max_times = retrieve_max_time_per_partition()
+        args.time = max_times.get(
+            args.partition, 3 * 24 * 60
+        )  # Default to 3 days if not found
+        print(
+            f"No time limit specified, using max time for partitions: {args.time} minutes"
+        )
+
+    if args.constraint:
+        args.constraint = f"#SBATCH --constraint={args.constraint}"
+
+    if args.account:
+        args.account = f"#SBATCH  --account={args.account}"
+
+    if args.qos:
+        args.qos = f"#SBATCH --qos={args.qos}"
+
+    if getattr(args, "exclude", ""):
+        args.exclude = f"#SBATCH --exclude={args.exclude}"
+
+    if hasattr(args, "anaconda") and args.anaconda:
+        if args.anaconda == "default":
+            args.anaconda = (
+                subprocess.check_output("which python", shell=True)
+                .decode("ascii")
+                .strip()
+            )
+        else:
+            args.anaconda = f"{args.anaconda}/bin/python"
+        assert os.path.isfile(args.anaconda)
+
+    args.mem = args.mem or "0"
+
+    assert args.partition
+    assert args.ngpu > 0
+    assert args.ncpu > 0
+    assert args.nodes > 0
+    assert args.time > 0
+
+
+def launch_job(args: StoolArgs):
+    # Set up argument defaults and validate them depending on the requested cluster or partition
+    validate_args(args)
+    dump_dir = args.config["dump_dir"]
+    job_name = args.config["name"]
+    print("Creating directories...")
+    os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
+    if args.override:
+        confirm = input(
+            f"Are you sure you want to delete the directory '{dump_dir}'? This action cannot be undone. (yes/no): "
+        )
+        if confirm.lower() == "yes":
+            shutil.rmtree(dump_dir)
+            print(f"Directory '{dump_dir}' has been deleted.")
+            # Recreate the empty dump dir so the config write below succeeds
+            os.makedirs(dump_dir)
+        else:
+            print("Operation cancelled.")
+            return
+    if args.copy_code:
+        os.makedirs(f"{dump_dir}/code", exist_ok=args.dirs_exists_ok)
+        print("Copying code ...")
+        copy_dir(os.getcwd(), f"{dump_dir}/code")
+
+    print("Saving config file ...")
+    with open(f"{dump_dir}/base_config.yaml", "w") as cfg:
+        cfg.write(OmegaConf.to_yaml(args.config))
+
+    conda_exe = os.environ.get("CONDA_EXE", "conda")
+    conda_env_path = os.path.dirname(os.path.dirname(args.anaconda))
+    log_output = (
+        "-o $DUMP_DIR/logs/%j/%j_%t.out -e $DUMP_DIR/logs/%j/%j_%t.err"
+        if not args.stdout
+        else ""
+    )
+    sbatch = SBATCH_COMMAND.format(
+        name=job_name,
+        script=args.script,
+        dump_dir=dump_dir,
+        nodes=args.nodes,
+        tasks=args.nodes * args.ngpu,
+        nodes_per_run=args.nodes,
+        ngpus=args.ngpu,
+        ncpu=args.ncpu,
+        mem=args.mem,
+        qos=args.qos,
+        account=args.account,
+        constraint=args.constraint,
+        exclude=args.exclude,
+        time=args.time,
+        partition=args.partition,
+        conda_exe=conda_exe,
+        conda_env_path=conda_env_path,
+        log_output=log_output,
+        go_to_code_dir=f"cd {dump_dir}/code/" if args.copy_code else "",
+    )
+
+    print("Writing sbatch command ...")
+    with open(f"{dump_dir}/submit.slurm", "w") as f:
+        f.write(sbatch)
+
+    print("Submitting job ...")
+    os.system(f"{args.launcher} {dump_dir}/submit.slurm")
+
+    print("Done.")
+
+
+if __name__ == "__main__":
+    """
+    The command-line interface here uses OmegaConf
+    (https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments),
+    which accepts arguments as a dot-list.
+    So if the dataclass looks like
+
+    @dataclass
+    class DummyArgs:
+        name: str
+        model: LMTransformerArgs
+
+    @dataclass
+    class LMTransformerArgs:
+        dim: int
+
+    Then you can pass model.dim=32 to change values in LMTransformerArgs
+    or just name=tictac for top-level attributes.
+    """
+    raise NotImplementedError("Update this to blt code")
+    args = OmegaConf.from_cli()
+    args.config = OmegaConf.load(args.config)
+    args = dataclass_from_dict(StoolArgs, args)
+    launch_job(args)
diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py
new file mode 100644
index 0000000..73ad9f7
--- /dev/null
+++ b/bytelatent/test_blt.py
@@ -0,0 +1,471 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import os
+from dataclasses import replace
+
+import numpy as np
+import pytest
+import torch
+
+from bytelatent.constants import BLT_DATA
+from bytelatent.data.data_types import Batch
+from bytelatent.data.ngram_processor import NgramProcessor
+from bytelatent.model.blt import (
+    ByteLatentTransformer,
+    ByteLatentTransformerArgs,
+    EmbeddingType,
+    compute_hash_embeddings,
+    create_global_transformer,
+    create_local_decoder,
+    create_local_encoder,
+    cross_attn_mask,
+    decoder_patch_ids_from_lengths,
+    get_blt_input,
+    init_embeddings,
+    patch_ids_from_lengths,
+)
+from bytelatent.model.transformer import CrossAttention
+from bytelatent.model.utils import create_causal_mask
+from bytelatent.optim import OptimArgs, build_optimizer
+from bytelatent.train import compute_loss
+
+
+def batch_to_tensors_and_gpu(batch):
+    x = torch.from_numpy(batch.x)
+    y = torch.from_numpy(batch.y)
+    mask = None if batch.mask is None else torch.from_numpy(batch.mask)
+    patch_lengths = (
+        None if batch.patch_lengths is None else torch.from_numpy(batch.patch_lengths)
+    )
+    ngram_ids = None if batch.ngram_ids is None else torch.from_numpy(batch.ngram_ids)
+
+    if torch.cuda.is_available():
+        x = x.cuda()
+        y = y.cuda()
+        if mask is not None:
+            mask = mask.cuda()
+        if patch_lengths is not None:
+            patch_lengths = patch_lengths.cuda()
+        if ngram_ids is not None:
+            ngram_ids = ngram_ids.cuda()
+    return x, y, mask, patch_lengths, ngram_ids
+
+
+def fake_batch():
+    batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"))
+    del batch_dict["x2"]
+    del batch_dict["y2"]
+    del batch_dict["src_names"]
+    return Batch(**batch_dict)
+
+
+def create_args(cross_attention=False):
+    transformer_args = ByteLatentTransformerArgs(
+        # Base args provided
+        n_heads=8,
+        dim=512,
+        vocab_size=260,
+        # Additional args from command line
+        dim_token=256,
+        patch_size=6,
+        tokenization_mode="bytes",
+        patching_mode="space",
+        tie_local_encoder_decoder_logits=False,
+        data_loader_patching=True,
+        max_encoder_seq_length=12288,
+        pad_to_max_length=True,
+        encoder_lm_loss=False,
+        patching_threshold=3.1439168453216553,
+        encoder_hash_byte_group_size=[4],
+        encoder_hash_byte_group_vocab=50002,
+        encoder_hash_byte_group_nb_functions=3,
+        cross_attn_encoder=cross_attention,  # True,
+        cross_attn_decoder=cross_attention,  # True,
+        cross_attn_window_encoder=512,
+        cross_attn_window_decoder=512,
+        dim_local_encoder=256,
+        dim_local_decoder=256,
+        cross_attn_k=8,
+        cross_attn_nheads=4,
+        cross_attn_all_layers_decoder=True,
+        cross_attn_all_layers_encoder=True,
+        cross_attn_use_flex_attention=True,
+        cross_attn_init_by_pooling=True,
+        log_patch_lengths=True,
+        non_linearity="swiglu",
+        use_rope=True,
+        recompute_fc1_out=False,
+        recompute_fc3_out=False,
+        recompute_attn=False,
+        custom_bwd=False,
+        layer_ckpt="none",
+        efficient_attn="sdpa",
+        patch_only_encoder=False,
+        patch_only_decoder=False,
+        use_local_encoder_transformer=True,
+        init_use_gaussian=True,
+        init_use_depth="current",
+        attn_bias_type="block_causal",
+        alpha_depth="disabled",
+        max_length=256,
+        local_attention_window_len=512,
+        max_seqlen=12288,
+        downsampling_by_pooling="max",
+    )
+    return transformer_args
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
+class TestByteLatentTransformer:
+    def test_local_encoder(self):
+        args = create_args()
+        device = torch.device("cuda")
+        local_encoder = create_local_encoder(args).to(device)
+
+        batch = fake_batch()
+        tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
+
+        local_encoder_tokens, _, _ = get_blt_input(
+            tokens=tokens,
+            enforce_patch_size_multiple=False,
+            nb_boe=0,
+            patch_size=local_encoder.patch_size,
+            boe_id=local_encoder.boe_id,
+        )
+
+        patch_ids = patch_ids_from_lengths(
+            patch_lengths, local_encoder_tokens.shape[-1]
+        )
+
+        encoder_hash_tok_embedding = init_embeddings(
+            args,
+            EmbeddingType.HASH_TOK,
+            local_encoder_dim=local_encoder.dim,
+            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+        ).to(device)
+
+        local_encoder_embeds = compute_hash_embeddings(
+            local_encoder_tokens=local_encoder_tokens,
+            local_encoder=local_encoder,
+            encoder_hash_tok_embedding=encoder_hash_tok_embedding,
+            encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
+            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+            encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
+        )
+
+        reference_path = os.path.join(BLT_DATA, "local_encoder_tokens.pt")
+        reference_tokens = torch.load(reference_path).to(device)
+        torch.testing.assert_close(
+            local_encoder_tokens,
+            reference_tokens,
+            msg="Generated tokens don't match reference tokens",
+        )
+
+        (h_encoder, h_cross), cache_encoder = local_encoder(
+            tokens=local_encoder_tokens,
+            embeds=local_encoder_embeds,
+            patch_embeds=None,
+            cross_mask=None,
+            num_patches=patch_lengths.shape[1],
+            patch_ids=patch_ids,
+        )
+
+        assert h_encoder is not None
+        assert h_cross is None
+        assert cache_encoder is None
+
+        expected_shape = (
+            local_encoder_tokens.shape[0],
+            local_encoder_tokens.shape[1],
+            local_encoder.dim,
+        )
+        assert h_encoder.shape == expected_shape
+
+    def test_local_encoder_cross_attention(self):
+        args = create_args(cross_attention=True)
+        device = torch.device("cuda")
+        local_encoder = create_local_encoder(args).to(device)
+
+        batch = fake_batch()
+        tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
+
+        local_encoder_tokens, _, _ = get_blt_input(
+            tokens=tokens,
+            enforce_patch_size_multiple=False,
+            nb_boe=0,
+            patch_size=local_encoder.patch_size,
+            boe_id=local_encoder.boe_id,
+        )
+
+        patch_ids = patch_ids_from_lengths(
+            patch_lengths, local_encoder_tokens.shape[-1]
+        )
+
+        encoder_hash_tok_embedding = init_embeddings(
+            args,
+            EmbeddingType.HASH_TOK,
+            local_encoder_dim=local_encoder.dim,
+            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+        ).to(device)
+
+        cross_attn_mask_enc = cross_attn_mask(
+            patch_ids,
+            patch_lengths,
+            local_encoder_tokens.shape[-1],
+            patches_as_queries=True,
+            cross_attn_k=args.cross_attn_k,
+            window=args.cross_attn_window_encoder,
+            block_mask=True,
+        )
+
+        local_encoder_embeds = compute_hash_embeddings(
+            local_encoder_tokens=local_encoder_tokens,
+            local_encoder=local_encoder,
+            encoder_hash_tok_embedding=encoder_hash_tok_embedding,
+            encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
+            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+            encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
+        )
+        (h_encoder, h_cross), cache_encoder = local_encoder(
+            tokens=local_encoder_tokens,
+            embeds=local_encoder_embeds,
+            patch_embeds=None,
+            cross_mask=cross_attn_mask_enc,
+            num_patches=patch_lengths.shape[1],
+            patch_ids=patch_ids,
+        )
+        assert h_encoder is not None
+        assert h_cross is not None
+        assert cache_encoder is None
+        expected_shape = (
+            local_encoder_tokens.shape[0],
+            local_encoder_tokens.shape[1],
+            local_encoder.dim,
+        )
+        assert h_encoder.shape == expected_shape
+        assert h_cross.shape == (2, 2048, local_encoder.dim)
+
+    def test_local_decoder_cross_attention(self):
+        args = create_args(cross_attention=True)
+        device = torch.device("cuda")
+        local_decoder = create_local_decoder(args).to(device)
+
+        test_files = {
+            "dec_embeds": "dec_embeds.pt",
+            "decoder_tokens": "local_decoder_tokens.pt",
+            "patch_embeds": "decoder_patch_cross_embeds.pt",
+        }
+        batch = fake_batch()
+        _, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
+
+        tensors = {
+            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
+            for name, filename in test_files.items()
+        }
+        decoder_patch_ids = decoder_patch_ids_from_lengths(
+            patch_lengths, 0, tensors["decoder_tokens"].shape[-1]
+        )
+        cross_attn_mask_dec = cross_attn_mask(
+            decoder_patch_ids,
+            patch_lengths,
+            tensors["decoder_tokens"].shape[-1],
+            patches_as_queries=False,
+            cross_attn_k=args.cross_attn_k,
+            window=args.cross_attn_window_decoder,
+            block_mask=True,
+        )
+        output, _ = local_decoder(
+            embeds=tensors["dec_embeds"],
+            patch_embeds=tensors["patch_embeds"],
+            tokens=tensors["decoder_tokens"],
+            cross_mask=cross_attn_mask_dec,
+            cache=None,
+        )
+        assert output is not None
+        assert output.shape == (2, tensors["decoder_tokens"].shape[1], args.vocab_size)
+
+    def test_local_decoder(self):
+        args = create_args()
+        device = torch.device("cuda")
+
+        local_decoder = create_local_decoder(args).to(device)
+
+        test_files = {
+            "dec_embeds": "dec_embeds.pt",
+            "decoder_tokens": "local_decoder_tokens.pt",
+            "patch_embeds": "decoder_patch_embeds.pt",
+        }
+
+        tensors = {
+            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
+            for name, filename in test_files.items()
+        }
+
+        output, cache_decoder = local_decoder(
+            embeds=tensors["dec_embeds"],
+            patch_embeds=tensors["patch_embeds"],
+            tokens=tensors["decoder_tokens"],
+            cross_mask=None,
+            cache=None,
+        )
+        assert output is not None
+        expected_shape = (
+            tensors["decoder_tokens"].shape[0],
+            tensors["decoder_tokens"].shape[1],
+            args.vocab_size,
+        )
+        assert output.shape == expected_shape
+        assert cache_decoder is None
+
+    def test_global_transformer(self):
+        args = create_args()
+        device = torch.device("cuda")
+        global_transformer = create_global_transformer(args).to(device)
+
+        test_files = {
+            "global_embeds": "global_embeds.pt",
+            "global_tokens": "global_tokens.pt",
+        }
+        tensors = {
+            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
+            for name, filename in test_files.items()
+        }
+        h, cache = global_transformer(
+            embeds=tensors["global_embeds"], tokens=tensors["global_tokens"]
+        )
+        assert h is not None
+        assert h.shape == (2, 256, 512)
+        assert cache is None
+
+    def test_blt_transformer_init(self):
+        args = create_args()
+        model = ByteLatentTransformer(args)
+        assert model is not None
+
+    @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"])
+    def test_blt_transformer_forward(self, attn_type):
+        args = create_args()
+        args = args.model_copy(update=dict(efficient_attn=attn_type))
+        model = ByteLatentTransformer(args)
+        model = model.cuda()
+        batch = fake_batch()
+        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
+
+        output = model(
+            tokens=x,
+            patch_lengths=patch_lengths,
+            ngram_ids=ngram_ids,
+        )
+        assert output is not None
+        expected_shape = (
+            x.shape[0],
+            x.shape[1],
+            args.vocab_size,
+        )
+        assert output.shape == expected_shape
+
+    def test_blt_transformer_cross_attn_forward(self):
+        args = create_args(cross_attention=True)
+        model = ByteLatentTransformer(args)
+        model = model.cuda()
+        batch = fake_batch()
+        x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
+
+        output = model(
+            tokens=x,
+            patch_lengths=patch_lengths,
+            ngram_ids=ngram_ids,
+        )
+        assert output is not None
+        expected_shape = (
+            x.shape[0],
+            x.shape[1],
+            args.vocab_size,
+        )
+        assert output.shape == expected_shape
+
+    def test_cross_attention_rand(self):
+        x = torch.randn(2, 256, 512, device="cuda")
+        kv = torch.randn(2, 256, 512, device="cuda")
+        cross_attention = CrossAttention(
+            dim=512,
+            head_dim=64,
+            n_heads=8,
+            n_kv_heads=4,
+            norm_eps=1e-6,
+        ).to("cuda")
+        mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None)
+        output = cross_attention(x, kv, mask)
+        assert output is not None
+        assert output.shape == (2, 256, 512)
+
+    def test_ngram_embeddings(self):
+        ngram_to_size = {
+            2: 38396,
+            3: 50000,
+            4: 50000,
+            5: 50000,
+            6: 50000,
+            7: 50000,
+            8: 50000,
+        }
+        batch = fake_batch()
+        ngram_processor = NgramProcessor(BLT_DATA, ngram_to_size)
+        ngram_ids = ngram_processor.encode_token_ngrams(batch.x)
+        ngram_ids = np.stack(ngram_ids, axis=0)
+        batch = replace(batch, ngram_ids=ngram_ids)
+        args = create_args(cross_attention=True)
+        args = args.model_copy(
+            update=dict(
+                encoder_ngram_to_size_str="2:38396,3:50000,4:50000,5:50000,6:50000,7:50000,8:50000",
+                encoder_enable_byte_ngrams=True,
+                ngram_vocab_sizes=ngram_processor.ngram_vocab_sizes,
+            )
+        )
+        model = ByteLatentTransformer(args)
+        model = model.cuda()
+        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
+
+        output = model(
+            tokens=x,
+            patch_lengths=patch_lengths,
+            ngram_ids=ngram_ids,
+        )
+        assert output is not None
+        expected_shape = (
+            x.shape[0],
+            x.shape[1],
+            args.vocab_size,
+        )
+        assert output.shape == expected_shape
+
+    def test_loss_backward(self):
+        args = create_args()
+        args = args.model_copy(update=dict(efficient_attn="sdpa"))
+        batch = fake_batch()
+        model = ByteLatentTransformer(args)
+        steps = 10
+        optimizer, scheduler = build_optimizer(model, OptimArgs(lr=4e-04), steps)
+        model = model.cuda()
+        x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
+
+        initial_loss = None
+        final_loss = None
+        for step in range(steps):
+            output = model(
+                tokens=x,
+                patch_lengths=patch_lengths,
+                ngram_ids=ngram_ids,
+            )
+            loss, _ = compute_loss(output, y, mask, 1.0)
+            if step == 0:
+                initial_loss = loss.item()
+            if step == steps - 1:
+                final_loss = loss.item()
+            loss.backward()
+            optimizer.step()
+            scheduler.step()
+            optimizer.zero_grad()
+        assert (
+            final_loss < initial_loss
+        ), f"Training did not reduce loss: initial {initial_loss}, final {final_loss}"
diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py
new file mode 100644
index 0000000..3acc42d
--- /dev/null
+++ b/bytelatent/test_entropy_model.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import os
+
+import torch
+
+from bytelatent.constants import BLT_DATA
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
+from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
+from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum, entropy
+from bytelatent.entropy_model import load_entropy_model
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+
+ENTROPY_MODEL = "transformer_100m"
+ARROW_TEST_DATA = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
+
+
+def test_entropy_model():
+    initial_state = ArrowFileIteratorState(
+        file_path=None,
+        num_workers=1,
+        worker_id=0,
+        preprocess_dir=None,
+        entropy_model_name=ENTROPY_MODEL,
+        dataset_files=[ARROW_TEST_DATA],
+        row_num=0,
+        arrow_batch_size=100,
+    )
+    arrow_file = initial_state.build()
+    tokenizer_args = TokenizerArgs(
+        name="blt",
+        init_kwargs={
+            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
+        },
+    )
+    entropy_model = load_entropy_model(
+        BLT_DATA / "checkpoint_0100000_consolidated",
+        os.path.join(
+            BLT_DATA,
+            "entropy_model.pth",
+        ),
+    )
+    preprocess_iter = PreprocessIterator(
+        arrow_file,
+        tokenizer_args=tokenizer_args,
+        patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.entropy),
+        add_patches=False,
+    )
+    for example in preprocess_iter.create_iter():
+        tokens = torch.tensor(example.tokens).unsqueeze(0)
+        expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
+        preds = entropy_model(tokens)
+        pred_entropies = entropy(preds)
+        assert pred_entropies.shape == expected_entropies.shape
+        assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5)
+        break
diff --git a/bytelatent/tokenizers/__init__.py b/bytelatent/tokenizers/__init__.py
new file mode 100644
index 0000000..71ca4b1
--- /dev/null
+++ b/bytelatent/tokenizers/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
diff --git a/bytelatent/tokenizers/abstract_tokenizer.py b/bytelatent/tokenizers/abstract_tokenizer.py
new file mode 100644
index 0000000..fff26c3
--- /dev/null
+++ b/bytelatent/tokenizers/abstract_tokenizer.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import abc
+
+
+class Tokenizer(abc.ABC):
+    @abc.abstractmethod
+    def encode(self, text: str, add_bos: bool, add_eos: bool):
+        pass
+
+    @abc.abstractmethod
+    def decode(self, tokens: list[int]):
+        pass
+
+    @abc.abstractmethod
+    def get_token_offsets(
+        self, text: str, tokens: list[int] | None = None
+    ) -> tuple[list[str], list[int]]:
+        """Return the offsets of the tokens in the original text. Only used for evaluation."""
+        pass
diff --git a/bytelatent/tokenizers/blt_tokenizer.py b/bytelatent/tokenizers/blt_tokenizer.py
new file mode 100644
index 0000000..a3462e1
--- /dev/null
+++ b/bytelatent/tokenizers/blt_tokenizer.py
@@ -0,0 +1,150 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import re
+
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
+from bytelatent.tokenizers.constants import (
+    BOE_ID,
+    BOS_ID,
+    BPE_ID,
+    BYTE_UNITS,
+    EOS_ID,
+    OFFSET,
+    PAD_ID,
+)
+from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
+
+
+def convert_to_bytes(s):
+    # Check whether the piece is a byte token of the form <0x00>
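+    # e.g. convert_to_bytes("<0x0A>") == b"\n" and convert_to_bytes("hi") == b"hi"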
+    if re.match(r"<0x[0-9a-fA-F]+>", s):
+        return bytes.fromhex(s[3:-1])
+    else:
+        return bytes(s, "utf-8", errors="ignore")
+
+
+def text2bytes_bpe_delims(
+    text: str,
+    *,
+    bpe_tokenizer,
+    bpe_id: int,
+    offsetting_special_char: int,
+    add_bos: bool,
+    add_eos: bool,
+):
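+    """Byte-tokenize `text`, inserting a BPE-delimiter token before the bytes of each
+    piece produced by `bpe_tokenizer` (leading-space pieces are merged first)."""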
+    cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos)
+    # merge the leading space tokens
+    leading_space_tokens = []
+    other_bpe_tokens = []
+    leading = True
+    for token in cur_bpe:
+        bpe_str = bpe_tokenizer.sp_model.id_to_piece(token)
+        if leading and all(c == "▁" for c in bpe_str):
+            leading_space_tokens.append(bpe_str)
+        else:
+            leading = False
+            other_bpe_tokens.append(bpe_str)
+    cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens
+
+    # Remove the '▁' characters
+    bpe_strs = []
+    for i, bpe_str in enumerate(cur_bpe_strs):
+        if (
+            len(bpe_strs) <= 1
+            and all([c == " " for s in bpe_strs for c in s])
+            and not all(c == "▁" for c in bpe_str)
+        ):
+            # Remove the leading space marker from the first non-space token.
+            bpe_str = bpe_str.replace("▁", "")
+        elif i == 0 and all(c == "▁" for c in bpe_str):
+            bpe_str = " " * (len(text) - len(text.lstrip(" ")))
+        else:
+            bpe_str = bpe_str.replace("▁", " ")
+        if len(bpe_str) > 0:
+            bpe_strs.append(bpe_str)
+    ex_seq = []
+    # Convert bpe tokens to bytes
+    for s in bpe_strs:
+        byte_chunk = convert_to_bytes(s)
+        proc_chunk = [int(unit) for unit in byte_chunk]
+        ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk)
+
+    return ex_seq
+
+
+class BltTokenizer(Tokenizer):
+    def __init__(
+        self,
+        *,
+        vocab_size_unit_1: int = BYTE_UNITS,
+        bpe_delim: bool = False,
+        bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
+        add_bos: bool = True,
+        add_eos: bool = True,
+    ):
+        self.add_bos = add_bos
+        self.add_eos = add_eos
+        self.vocab_size_unit_1 = vocab_size_unit_1
+        self.boe_id = BOE_ID
+        self.bos_id = BOS_ID
+        self.eos_id = EOS_ID
+        self.pad_id = PAD_ID
+        self.bpe_id = BPE_ID
+        self.bpe_tokenizer_path = bpe_tokenizer_path
+        if bpe_delim:
+            self.bpe_tokenizer = SentencePieceTokenizer(
+                model_path=self.bpe_tokenizer_path
+            )
+        else:
+            self.bpe_tokenizer = None
+        self.bpe_delim = bpe_delim
+        self.offsetting_special_char = OFFSET
+        self.n_words = vocab_size_unit_1 + self.offsetting_special_char
+
+    def encode(
+        self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
+    ):
+        if add_bos is None:
+            add_bos = self.add_bos
+        if add_eos is None:
+            add_eos = self.add_eos
+
+        if self.bpe_delim:
+            tokens = text2bytes_bpe_delims(
+                text,
+                bpe_tokenizer=self.bpe_tokenizer,
+                bpe_id=self.bpe_id,
+                offsetting_special_char=self.offsetting_special_char,
+                add_bos=False,
+                add_eos=False,
+            )
+        else:
+            tokens = bytes(text, encoding="utf-8", errors="ignore")
+
+        # Offset raw byte values past the special-token ids
+        tokens = [int(unit) + self.offsetting_special_char for unit in tokens]
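+        # e.g. with the default OFFSET of 4, the byte "A" (65) becomes token 69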
+
+        if add_bos:
+            tokens.insert(0, self.bos_id)
+        if add_eos:
+            tokens.append(self.eos_id)
+
+        return tokens
+
+    def decode(self, tokens: list[int], cut_at_eos: bool = False):
+        if cut_at_eos:
+            for k, t in enumerate(tokens):
+                if t == self.eos_id:
+                    tokens = tokens[: k + 1]
+                    break
+        return bytes(
+            [
+                tok - self.offsetting_special_char
+                for tok in tokens
+                if tok - self.offsetting_special_char >= 0
+            ]
+        ).decode("utf-8", errors="ignore")
+
+    def get_token_offsets(self, text: str, tokens: list[int] | None = None):
+        # TODO: Figure out what this does
+        raise NotImplementedError()
diff --git a/bytelatent/tokenizers/build_tokenizer.py b/bytelatent/tokenizers/build_tokenizer.py
new file mode 100644
index 0000000..8aa434d
--- /dev/null
+++ b/bytelatent/tokenizers/build_tokenizer.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
+from typing import Any
+
+from pydantic import BaseModel
+
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+from bytelatent.tokenizers.byte_tokenizer import ByteTokenizer
+from bytelatent.tokenizers.tiktoken_tokenizer import TikTokenTokenizer
+
+try:
+    from sentencepiece import SentencePieceProcessor
+
+    has_sp = True
+except ImportError:
+    has_sp = False
+
+try:
+    import tiktoken
+    from tiktoken.load import load_tiktoken_bpe
+
+    has_tiktoken = True
+except ImportError:
+    has_tiktoken = False
+
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
+from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class MockTokenizer(Tokenizer):
+    n_words: int = 256
+
+    def encode(self, text: str, add_bos: bool, add_eos: bool):
+        return text
+
+    def decode(self, tokens):
+        raise NotImplementedError()
+
+    def get_token_offsets(
+        self, text: str, tokens: list[int] | None = None
+    ) -> tuple[list[str], list[int]]:
+        raise NotImplementedError()
+
+
+class TokenizerArgs(BaseModel):
+    name: str = "bytes"
+    init_kwargs: dict[str, Any] | None = None
+
+    def build(self) -> Tokenizer:
+        if self.init_kwargs is None:
+            init_kwargs = {}
+        else:
+            init_kwargs = self.init_kwargs
+        if self.name == "blt":
+            return BltTokenizer(**init_kwargs)
+        elif self.name == "bytes":
+            return ByteTokenizer(**init_kwargs)
+        elif self.name == "mock":
+            return MockTokenizer(**init_kwargs)
+        elif self.name == "sp":
+            assert has_sp, "sentencepiece not installed"
+            return SentencePieceTokenizer(**init_kwargs)
+        elif self.name == "tiktoken":
+            assert has_tiktoken, "tiktoken not installed"
+            return TikTokenTokenizer(**init_kwargs)
+        else:
+            raise NotImplementedError(f"{self.name} tokenizer type is not implemented")
diff --git a/bytelatent/tokenizers/byte_tokenizer.py b/bytelatent/tokenizers/byte_tokenizer.py
new file mode 100644
index 0000000..f85f4f7
--- /dev/null
+++ b/bytelatent/tokenizers/byte_tokenizer.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
+
+
+class ByteTokenizer(Tokenizer):
+    def __init__(self):
+        self.bos_id = 256
+        self.eos_id = 257
+        self.n_words = 258
+
+    def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
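+        # The bool flags multiply the singleton lists: True keeps the token, False drops it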
+        tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
+        return tokens
+
+    def decode(self, tokens: list[int]):
+        byte_tokens = bytes([t for t in tokens if t < 256])
+        return byte_tokens.decode("utf-8", errors="backslashreplace")
+
+    def get_token_offsets(
+        self, text: str, tokens: list[int] | None = None
+    ) -> tuple[list[str], list[int]]:
+        if tokens is None:
+            tokens = self.encode(text)
+
+        decoded_chars, offsets = [], []
+        byte_pos = 0
+        for token in tokens:
+            if token < 256:
+                char = bytes([token]).decode("utf-8", errors="ignore")
+                if char:
+                    decoded_chars.append(char)
+                    offsets.append(byte_pos)
+                byte_pos += len(char.encode("utf-8"))
+
+        return decoded_chars, offsets
diff --git a/bytelatent/tokenizers/constants.py b/bytelatent/tokenizers/constants.py
new file mode 100644
index 0000000..774119e
--- /dev/null
+++ b/bytelatent/tokenizers/constants.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+
+SEP = " "
+BOS_ID: int = 1
+EOS_ID: int = 2
+PAD_ID: int = -1
+BOE_ID: int = 0
+BPE_ID: int = 3
+OFFSET: int = 4
+
+BYTE_UNITS: int = 256
diff --git a/bytelatent/tokenizers/sentence_piece_tokenizer.py b/bytelatent/tokenizers/sentence_piece_tokenizer.py
new file mode 100644
index 0000000..faeb997
--- /dev/null
+++ b/bytelatent/tokenizers/sentence_piece_tokenizer.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
+import os
+
+try:
+    from sentencepiece import SentencePieceProcessor
+
+    has_sp = True
+except ImportError:
+    has_sp = False
+
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class SentencePieceTokenizer(Tokenizer):
+    def __init__(
+        self, model_path: str, add_bos: bool = True, add_eos: bool = True
+    ) -> None:
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+        logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.pad_id()
+        self.add_bos = add_bos
+        self.add_eos = add_eos
+        logger.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
+        if add_bos is None:
+            add_bos = self.add_bos
+
+        if add_eos is None:
+            add_eos = self.add_eos
+        assert type(s) is str
+        tokens = (
+            [self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos
+        )
+        return tokens
+
+    def decode(self, tokens: list[int]):
+        return self.sp_model.decode(tokens)
+
+    def get_token_offsets(
+        self, text: str, tokens: list[int] | None = None
+    ) -> tuple[list[str], list[int]]:
+        pieces = self.sp_model.encode_as_immutable_proto(text).pieces
+        substrs = [p.surface for p in pieces]
+        offsets = [p.begin for p in pieces]
+        return substrs, offsets
diff --git a/bytelatent/tokenizers/test_blt_tokenizer.py b/bytelatent/tokenizers/test_blt_tokenizer.py
new file mode 100644
index 0000000..50308da
--- /dev/null
+++ b/bytelatent/tokenizers/test_blt_tokenizer.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import json
+
+from bytelatent.constants import BLT_DATA
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+
+
+def test_tokenizer_bytes():
+    with open("fixtures/tokenizer_data.json") as f:
+        data = json.load(f)
+
+    examples: list[str] = data["texts"]
+    examples_tokens: list[list[int]] = data["tokens"]
+
+    tokenizer = BltTokenizer(bpe_delim=False)
+    for i in range(len(examples)):
+        assert tokenizer.encode(examples[i]) == examples_tokens[i]
+
+
+def test_tokenizer_bpe():
+    with open("fixtures/tokenizer_data_bpe_delim.json") as f:
+        data = json.load(f)
+
+    examples: list[str] = data["texts"]
+    examples_tokens: list[list[int]] = data["tokens"]
+
+    tokenizer = BltTokenizer(bpe_delim=True)
+    for i in range(len(examples)):
+        assert tokenizer.encode(examples[i]) == examples_tokens[i]
+
+
+def test_build_tokenizer_from_args():
+    tokenizer_args = TokenizerArgs(
+        name="blt",
+        init_kwargs={
+            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
+        },
+    )
+    tokenizer = tokenizer_args.build()
+    assert tokenizer.encode("test text") is not None
diff --git a/bytelatent/tokenizers/tiktoken_tokenizer.py b/bytelatent/tokenizers/tiktoken_tokenizer.py
new file mode 100644
index 0000000..f498bf2
--- /dev/null
+++ b/bytelatent/tokenizers/tiktoken_tokenizer.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
+from copy import copy
+from pathlib import Path
+
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
+
+try:
+    import tiktoken
+    from tiktoken.load import load_tiktoken_bpe
+
+    has_tiktoken = True
+except ImportError:
+    has_tiktoken = False
+DEFAULT_TIKTOKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+DEFAULT_TIKTOKEN_SPECIAL_TOKENS = {
+    "<|begin_of_text|>": 0,
+    "<|end_of_text|>": 1,
+    "<|fim_prefix|>": 2,
+    "<|fim_middle|>": 3,
+    "<|fim_end_fill|>": 253,
+    "<|fim_pad|>": 254,
+    "<|fim_suffix|>": 255,
+}
+TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+logger = logging.getLogger(__name__)
+
+
+class TikTokenTokenizer(Tokenizer):
+    def __init__(self, model_path: str) -> None:
+        mergeable_ranks = load_tiktoken_bpe(model_path)
+        all_special_tokens_with_ids = copy(DEFAULT_TIKTOKEN_SPECIAL_TOKENS)
+        missing_ids = set(range(256)) - set(all_special_tokens_with_ids.values())
+        for id in missing_ids:
+            all_special_tokens_with_ids[f"<|reserved_special_token_{id}|>"] = id
+        for name in all_special_tokens_with_ids:
+            all_special_tokens_with_ids[name] += len(mergeable_ranks)
+
+        self.tkt_model = tiktoken.core.Encoding(
+            name=Path(model_path).stem,
+            pat_str=DEFAULT_TIKTOKEN_PATTERN,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=all_special_tokens_with_ids,
+        )
+
+        self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>")
+        self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>")
+
+        self.n_words: int = self.tkt_model.n_vocab
+
+        logger.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+
+    def encode(self, s: str, add_bos: bool, add_eos: bool):
+        assert isinstance(s, str)
+
+        subs = []
+        for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
+            subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
+        return (
+            [self.bos_id] * add_bos
+            + sum(self.tkt_model.encode_ordinary_batch(subs), start=[])
+            + [self.eos_id] * add_eos
+        )
+
+    def decode(self, tokens: list[int]):
+        return self.tkt_model.decode(tokens)
+
+    def get_token_offsets(
+        self, text: str, tokens: list[int] | None = None
+    ) -> tuple[list[str], list[int]]:
+        if tokens is not None:
+            token_bytes = self.tkt_model.decode_tokens_bytes(tokens)
+        else:
+            token_bytes = self.tkt_model.decode_tokens_bytes(
+                self.tkt_model.encode(text, allowed_special="all")
+            )
+
+        # Offsets are character positions in `text`. A token that starts with a
+        # UTF-8 continuation byte (0x80-0xBF) continues the previous character,
+        # so its offset is pulled back by one; continuation bytes are likewise
+        # excluded when advancing the running character count.
+        text_len, offsets = 0, []
+        for token in token_bytes:
+            offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
+            text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
+        substrs = [text[s:e] for s, e in zip(offsets, offsets[1:] + [None])]
+        return substrs, offsets
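+
+
+# Illustrative usage sketch (added for clarity; the model path below is a
+# hypothetical tiktoken-format BPE file, not something shipped with this repo):
+#
+#   tok = TikTokenTokenizer("/path/to/tokenizer.model")
+#   ids = tok.encode("hello world", add_bos=True, add_eos=True)
+#   assert ids[0] == tok.bos_id and ids[-1] == tok.eos_id
+#   assert tok.decode(ids[1:-1]) == "hello world"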
diff --git a/bytelatent/train.py b/bytelatent/train.py
new file mode 100644
index 0000000..6cb13b9
--- /dev/null
+++ b/bytelatent/train.py
@@ -0,0 +1,651 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import gc
+import logging
+import os
+import sys
+from contextlib import ExitStack
+from copy import deepcopy
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from timeit import default_timer as timer
+from typing import Any, Dict, Type, TypeVar
+
+import torch
+import torch.distributed
+import torch.nn.functional
+import torch.nn.functional as F
+import wandb
+import xformers.profiler
+from omegaconf import OmegaConf
+from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.optim import lr_scheduler
+
+from bytelatent.args import TrainArgs
+from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
+from bytelatent.data.data_types import DataLoaderState
+from bytelatent.distributed import (
+    check_model_value_range,
+    clean_env,
+    dist_mean_dict,
+    get_device_mesh,
+    get_is_master,
+    get_world_size,
+    init_signal_handler,
+    parallelize_model,
+    requeue_slurm_job,
+    setup_env,
+    setup_torch_distributed,
+)
+from bytelatent.logger import init_logger
+from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.optim import build_optimizer
+from bytelatent.probe import AutoProbeD
+from bytelatent.profiling import maybe_run_profiler
+from bytelatent.stool import StoolArgs, launch_job
+from bytelatent.transformer import (
+    build_fsdp_grouping_plan,
+    get_no_recompute_ops,
+    get_num_flop_per_token,
+    tp_parallelize,
+)
+
+logger = logging.getLogger()
+
+T = TypeVar("T")
+
+
+def flatten_dict(d, parent_key="", sep="_"):
+    items = []
+    for k, v in d.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
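+
+# Illustrative example (added for clarity): with the default separator,
+# flatten_dict({"speed": {"wps": 1.0}, "step": 3}) returns
+# {"speed_wps": 1.0, "step": 3}; with sep="/" (as used for the metrics below)
+# the nested key becomes "speed/wps".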
+
+
+def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
+    """
+    Converts a dictionary to a dataclass instance, recursively for nested structures.
+    """
+    base = OmegaConf.structured(cls())
+    OmegaConf.set_struct(base, strict)
+    override = OmegaConf.create(data)
+    return OmegaConf.to_object(OmegaConf.merge(base, override))
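+
+# Illustrative sketch (hypothetical dataclass, added for clarity):
+#
+#   @dataclass
+#   class OptimCfg:
+#       lr: float = 1e-3
+#
+#   dataclass_from_dict(OptimCfg, {"lr": 3e-4})  # -> OptimCfg(lr=0.0003)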
+
+
+@dataclass
+class TrainState(Stateful):
+    step: int  # Nb of steps taken by the optimizer
+    acc_step: int  # Nb of accumulation steps done since last optimizer step
+    scheduler: lr_scheduler.LambdaLR
+    data_loader_state: DataLoaderState
+    scale: float = 1.0
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "step": self.step,
+            "acc_step": self.acc_step,
+            "data_loader_state": self.data_loader_state.dict(),
+            "scheduler": self.scheduler.state_dict(),
+        }
+
+    def load_state_dict(self, state_dict):
+        self.step = state_dict["step"]
+        self.acc_step = state_dict["acc_step"]
+        self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"])
+        self.scheduler.load_state_dict(state_dict["scheduler"])
+
+
+def validate_train_args(args: TrainArgs, output_size: int):
+    if args.model.vocab_size < 0:
+        logger.info(f"Setting model output size to {args.model.vocab_size}")
+        args.model.vocab_size = output_size
+
+    assert args.dump_dir, "Dump dir not set"
+
+    if args.checkpoint.path is None:
+        logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
+        args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
+
+    for source in args.data.sources:
+        data_path = os.path.join(args.data.root_dir, source)
+        assert os.path.exists(data_path), f"{data_path} doesn't exist"
+
+    if (
+        args.distributed.dp_replicate
+        * args.distributed.dp_shard
+        * args.distributed.tp_size
+        != get_world_size()
+    ):
+        assert get_world_size() % args.distributed.dp_shard == 0
+        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
+
+        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
+        args.distributed.dp_replicate = (
+            args.distributed.dp_replicate // args.distributed.tp_size
+        )
+
+        logger.warning(
+            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
+        )
+        assert (
+            args.distributed.dp_replicate
+            * args.distributed.dp_shard
+            * args.distributed.tp_size
+            == get_world_size()
+        )
+
+        if args.distributed.fsdp_type == "no_shard":
+            assert (
+                args.distributed.dp_shard == 1
+                and args.distributed.dp_replicate == get_world_size()
+            )
+
+    args.model.max_seqlen = args.data.seq_len
+
+    if args.distributed.tp_size > 1:
+        logger.warning(
+            "Tensor parallelism has not been tested for a while, use at your own risk"
+        )
+
+    assert (
+        args.probe_freq != args.profiling.mem_steps
+    ), "Don't profile during probe step"
+    assert (
+        args.probe_freq != args.profiling.profile_steps
+    ), "Don't profile during probe step"
+    if args.logging.wandb is not None:
+        args.logging.wandb.name = args.name
+
+    if args.probe_freq is not None:
+        assert (
+            args.distributed.tp_size == 1
+        ), "Probing not supported with tensor parallelism"
+        assert (
+            args.distributed.selective_activation_checkpointing is False
+        ), "Probing not supported with selective activation checkpointing"
+
+
+preemption_flag = dict(flag=False)
+
+
+def set_preemption_flag(signum, frame):
+    logger.warning("Signal handler called with signal " + str(signum))
+    logger.warning("Preemption ! checkpointing asap and exiting.")
+    preemption_flag["flag"] = True
+
+
+def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
+    test = train_state.step % freq == 0
+    if acc_step is not None:
+        test = test and (train_state.acc_step == acc_step)
+    elif acc_freq is not None:
+        test = test and ((train_state.acc_step % acc_freq) == 0)
+    return test
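+
+# Example (added for clarity): every_n_steps(train_state, freq=100, acc_step=0)
+# is True only on steps that are multiples of 100, and only once per
+# gradient-accumulation cycle (when acc_step has wrapped back to 0).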
+
+
+def compute_loss(p, y, mask, scale):
+    tok_loss = scale * F.cross_entropy(
+        p.flatten(0, 1), y.flatten(0, 1), reduction="none"
+    )
+    if mask is None:
+        loss = tok_loss.mean()
+    else:
+        mask = mask.flatten(0, 1)
+        tok_loss = tok_loss * mask
+        loss = tok_loss.sum() / (mask.sum() + 1e-6)
+    return loss, tok_loss
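+
+# Example (added for clarity): with mask = [[1, 1, 0]] only the first two token
+# losses contribute, so loss = (l0 + l1) / (2 + 1e-6); the 1e-6 term guards
+# against division by zero for an all-zero mask.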
+
+
+def train(args: TrainArgs):
+    with ExitStack() as context_stack:
+        tokenizer = args.data.tokenizer_args.build()
+        validate_train_args(
+            args,
+            tokenizer.n_words,
+        )
+        if get_is_master():
+            os.makedirs(args.dump_dir, exist_ok=True)
+            args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml")
+        init_logger(Path(args.dump_dir) / "train.log")
+        init_signal_handler(set_preemption_flag)  # For handling preemption signals.
+        setup_env(args.env)
+        setup_torch_distributed(args.distributed)
+        world_mesh = get_device_mesh(args.distributed)
+        logger.info(f"Starting job: {args.name}")
+
+        # build dataloader
+        # need dp world size and rank
+        dp_mesh = world_mesh["dp_replicate"]
+        dp_degree = dp_mesh.size()
+        dp_rank = dp_mesh.get_local_rank()
+        if args.distributed.dp_shard > 1:
+            dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
+            dp_degree *= world_mesh["dp_shard"].size()
+
+        logger.info(f"Running on dp rank : {dp_rank}")
+        logger.info(f"Running on dp size : {dp_degree}")
+
+        torch.manual_seed(args.seed)
+        logger.info("Building model")
+
+        # Initializing the model on the meta device lets us build models much bigger than one GPU's memory
+        with torch.device("meta"):
+            model = ByteLatentTransformer(args.model)
+        logger.info("Model is built !")
+
+        model_param_count = get_num_params(model)
+
+        model = parallelize_model(
+            model,
+            world_mesh,
+            args.model,
+            args.distributed,
+            fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
+            tp_parallelize=tp_parallelize,
+            no_recompute_ops=get_no_recompute_ops(),
+        )
+
+        # Once the model is sharded across GPUs we can actually initialize it.
+        # First we create empty tensors of the correct shapes.
+        model = model.to_empty(device="cuda")
+        # Then we init the model. Please make sure this function initializes *ALL* parameters
+        # and buffers, otherwise you will have random values in the uninitialized tensors,
+        # which fail silently (for example by producing NaN gradients).
+
+        if args.checkpoint.init_ckpt_path:
+            logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
+            load_from_checkpoint(
+                args.checkpoint.init_ckpt_path, model, model_key="model"
+            )  # Put model_key="" if its directly the model checkpoint
+            model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
+        else:
+            with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
+                torch.manual_seed(args.model.seed)
+                model.init_weights()
+        check_model_value_range(model, range=10.0, std=1.0)
+
+        # log model size
+
+        logger.info(f"Model size: {model_param_count:,} total parameters")
+
+        gpu_memory_monitor = GPUMemoryMonitor("cuda")
+        logger.info(
+            f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
+            f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
+        )
+        logger.info(f"GPU memory usage: {gpu_memory_monitor}")
+
+        # build optimizer after applying parallelisms to the model
+        optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
+        data_loader = args.data.build_from_rank(dp_rank, dp_degree)
+        data_loader_state = data_loader.get_state()
+
+        train_state = TrainState(
+            step=0,
+            acc_step=0,
+            data_loader_state=data_loader_state,
+            scheduler=scheduler,
+            scale=1.0,
+        )
+
+        checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
+        # Either load from the latest checkpoint or start from scratch
+        checkpoint.load(model, optimizer, train_state, world_mesh)
+        if args.probe_freq is not None:
+            if get_is_master():
+                os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
+            torch.distributed.barrier()
+            probe = AutoProbeD(
+                model,
+                (
+                    Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
+                    if (dp_rank % 128 == 0)
+                    else None
+                ),
+            )
+            probe_mod = model._orig_mod if args.distributed.compile else model
+
+        gc.disable()
+
+        # train loop
+        model.train()
+        metric_logger = context_stack.enter_context(
+            MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
+        )
+        data_loader = train_state.data_loader_state.build()
+        batch_iterator = data_loader.create_iter()
+
+        torch_profiler = context_stack.enter_context(
+            maybe_run_profiler(args.dump_dir, model, args.profiling)
+        )
+
+        nwords_since_last_log = 0
+        time_last_log = timer()
+        gc.collect()
+        while train_state.step < args.steps:
+            # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
+            train_state.acc_step += 1
+            train_state.acc_step = train_state.acc_step % args.grad_acc_steps
+
+            # get batch
+            curr_lr = float(optimizer.param_groups[0]["lr"])
+            data_load_start = timer()
+            batch = next(batch_iterator)
+            batch_x = torch.from_numpy(
+                batch.x,
+            ).cuda()
+            batch_y = torch.from_numpy(batch.y).cuda()
+            batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
+            mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+
+            if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None:
+                raise ValueError(
+                    "Cannot enable byte ngrams and have batch.ngram_ids be None"
+                )
+            ngram_ids = (
+                None
+                if batch.ngram_ids is None
+                else torch.from_numpy(batch.ngram_ids).cuda()
+            )
+
+            if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
+                logger.info("garbage collection")
+                # We run garbage collection manually; otherwise different processes
+                # trigger GC at different times and slow down the whole pipeline.
+                gc.collect()
+
+            data_load_time = round(timer() - data_load_start, 4)
+            nwords_since_last_log += batch_x.numel()
+
+            bsz, seqlen = batch_y.shape
+
+            # forward
+            start_timer = torch.cuda.Event(enable_timing=True)
+            end_timer = torch.cuda.Event(enable_timing=True)
+            start_timer.record()
+
+            # This is an automatic probe that computes statistics of all linears'
+            # inputs, weights and outputs, along with attention logits and entropy,
+            # in both the forward and backward passes.
+            tok_loss = None
+            if (args.probe_freq is not None) and every_n_steps(
+                train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
+            ):
+                # Here we do a fake forward and backward pass on a smaller
+                # batch size to avoid OOM
+                # This assumes the model has no stateful layers (batch norm..)
+                assert (
+                    next(probe_mod.parameters()).grad is None
+                ), "Can't probe model if grads are not reset"
+
+                with probe:
+                    probe.metadata = {
+                        "it": train_state.step,
+                        "global_step": train_state.step,
+                        "loop": "lingua",
+                    }
+                    # A non-compiled model uses roughly 2x memory in our experiments,
+                    # so we halve either the batch size or the sequence length.
+                    probe_bsz = max(1, bsz // 2)
+                    probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
+                    probe_loss = probe_mod(
+                        batch_x[:probe_bsz, :probe_seq],
+                        batch_y[:probe_bsz, :probe_seq],
+                    )
+                    probe_loss.backward()
+                    # We zero grads to cancel this fake step
+                    optimizer.zero_grad()
+
+                assert (
+                    next(probe_mod.parameters()).grad is None
+                ), "Probe model shouldn't have grads at this point"
+
+            pred = model(
+                batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
+            )
+
+            loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
+
+            # We scale loss with grad_acc_steps so the gradient is the same
+            # regardless of grad_acc_steps
+            loss = loss / args.grad_acc_steps
+
+            # backward on scaled loss to create scaled gradients
+            loss.backward()
+            # For logging we undo that scaling
+            loss = loss.detach() * args.grad_acc_steps
+
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                model.parameters(), max_norm=args.optim.clip, foreach=True
+            )
+
+            grad_norm = (
+                grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
+            ).item()
+
+            # optimizer step
+            if train_state.acc_step == 0:
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+                train_state.step += 1
+
+            # training iteration complete
+            end_timer.record()
+
+            torch.cuda.synchronize()
+
+            curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)
+
+            # if profiler is active
+            if torch_profiler:
+                xformers.profiler.step()
+
+            # log metrics
+            if every_n_steps(
+                train_state,
+                args.logging.freq,
+                acc_step=None if args.logging.acc_freq else 0,
+                acc_freq=args.logging.acc_freq,
+            ):
+                time_delta = timer() - time_last_log
+                wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)
+
+                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
+
+                total_acc_steps = (
+                    args.grad_acc_steps * train_state.step + train_state.acc_step
+                )
+                tokens_per_gpu = (
+                    total_acc_steps * args.data.batch_size * args.data.seq_len
+                )
+                total_tokens = dp_degree * tokens_per_gpu
+                # This is an estimate; the correct value changes with the architecture.
+                # Use xformers' profiler trace analysis to get an actual measurement.
+                FLOPS = (
+                    get_num_flop_per_token(
+                        model_param_count - args.model.vocab_size * args.model.dim,
+                        args.model.n_layers,
+                        args.model.dim,
+                        args.data.seq_len,
+                    )
+                    * wps
+                )
+
+                metrics = flatten_dict(
+                    {
+                        "global_step": train_state.step,
+                        "acc_step": train_state.acc_step,
+                        "speed": {
+                            "wps": wps,
+                            "FLOPS": FLOPS,
+                            "curr_iter_time": curr_iter_time,
+                            "data_load_time": data_load_time,
+                        },
+                        "optim": {
+                            "grad_norm": grad_norm,
+                            "lr": curr_lr,
+                            "total_tokens": total_tokens,
+                        },
+                        "memory": gpu_mem_stats._asdict(),
+                    },
+                    sep="/",
+                )
+
+                to_sync = {}
+                to_sync["loss/out"] = loss.item()
+                metrics.update(dist_mean_dict(to_sync))
+
+                if get_is_master():
+                    metric_logger.log(metrics)
+
+                gpu_memory_monitor.reset_peak_stats()
+                nwords_since_last_log = 0
+                time_last_log = timer()
+                logger.info(
+                    f"step: {train_state.step}"
+                    f"  acc: {train_state.acc_step}"
+                    f"  loss: {round(loss.item(),4):>7}"
+                    f"  grad: {grad_norm:.2e}"
+                    f"  flops: {FLOPS:.2e}"
+                    f"  wps: {wps:.2e}"
+                    f"  iter: {curr_iter_time:>7}"
+                    f"  data: {data_load_time:>5}"
+                    f"  lr: {curr_lr:.2e}"
+                    f"  mem: {gpu_mem_stats.max_active_pct:.0f}%"
+                    f"  pow: {gpu_mem_stats.power_draw/1000} W"
+                )
+
+            saved = False
+            if every_n_steps(
+                train_state, args.checkpoint.dump.every, acc_step=0
+            ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
+                saved = checkpoint.save(
+                    model,
+                    optimizer,
+                    train_state,
+                    args,
+                    device_mesh=world_mesh,
+                )
+
+            if args.eval is not None and every_n_steps(
+                train_state, args.checkpoint.eval.every, acc_step=0
+            ):
+                from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
+
+                eval_args = dataclass_from_dict(EvalArgs, args.eval)
+
+                eval_args.global_step = train_state.step
+                eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
+                eval_args.dump_dir = str(
+                    os.path.join(
+                        args.dump_dir,
+                        "evals",
+                        EVAL_FOLDER_NAME.format(train_state.step),
+                    )
+                )
+                eval_args.metric_log_dir = args.dump_dir
+                if args.async_eval_gpus is None:
+                    launch_eval(eval_args)
+                elif get_is_master():
+                    if wandb.run is not None and args.logging.wandb is not None:
+                        eval_args.wandb = deepcopy(args.logging.wandb)
+                    assert args.async_eval_gpus > 0
+                    logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
+                    with clean_env():
+                        launch_job(
+                            StoolArgs(
+                                asdict(eval_args),
+                                script="apps.main.eval",
+                                copy_code=False,
+                                nodes=args.async_eval_gpus // 8,
+                                qos="lowest",
+                            )
+                        )
+
+            if preemption_flag["flag"]:
+                if not saved:
+                    checkpoint.save(
+                        model,
+                        optimizer,
+                        train_state,
+                        args,
+                        device_mesh=world_mesh,
+                    )
+                requeue_slurm_job()
+                sys.exit(0)
+
+    if not saved:
+        checkpoint.save(
+            model,
+            optimizer,
+            train_state,
+            args,
+            device_mesh=world_mesh,
+        )
+    gc.collect()
+
+
+def main():
+    """
+    The command line interface here uses OmegaConf
+    https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
+    It accepts arguments as a dot list. So if the dataclass looks like
+
+    @dataclass
+    class DummyArgs:
+        name: str
+        model: LMTransformerArgs
+
+    @dataclass
+    class LMTransformerArgs:
+        dim: int
+
+    then you can pass model.dim=32 to change values in LMTransformerArgs,
+    or just name=tictac for top-level attributes.
+
+    The behavior here is as follows:
+    1. We instantiate TrainArgs with its default values
+    2. We override those default values with the ones in the provided config file
+    3. We override the result with the additional arguments provided through command line
+
+    For example, if the config is the following
+
+    model:
+        dim: 128
+        n_layers: 4
+
+    and you call train.py with train.py model.dim=64
+
+    Then the final TrainArgs will have
+
+    model:
+        dim: 64
+        n_layers: 4
+
+    Plus all the default values in TrainArgs dataclass.
+    """
+    cli_args = OmegaConf.from_cli()
+    file_cfg = OmegaConf.load(cli_args.config)
+    # We remove the 'config' attribute from the config, as the underlying dataclass does not have it
+    del cli_args.config
+
+    default_cfg = OmegaConf.create(TrainArgs().model_dump())
+    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
+    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
+    train_args = TrainArgs.model_validate(cfg)
+    train(train_args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
new file mode 100644
index 0000000..432f7df
--- /dev/null
+++ b/bytelatent/transformer.py
@@ -0,0 +1,216 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+    parallelize_module,
+)
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+from xformers.ops import AttentionBias, fmha
+
+from bytelatent.base_transformer import (
+    BaseTransformer,
+    BaseTransformerArgs,
+    RMSNorm,
+    cross_entropy,
+)
+
+
+def create_causal_mask(seqlen, attn_impl, sliding_window):
+    if sliding_window is not None and attn_impl == "xformers":
+        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
+            window_left=sliding_window - 1, window_right=0
+        )
+    elif attn_impl == "xformers":
+        return fmha.attn_bias.LowerTriangularMask()
+    elif attn_impl == "sdpa":
+        return "causal"
+    elif attn_impl == "flex_attention":
+        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
+    else:
+        raise NotImplementedError(
+            f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
+        )
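+
+# Example (added for clarity): create_causal_mask(4096, "xformers", 512) builds a
+# sliding-window attention bias, create_causal_mask(4096, "sdpa", None) returns the
+# sentinel string "causal", and "flex_attention" builds a BlockMask over the sequence.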
+
+
+def attention_flops_per_token(n_layers, seq_len, dim, causal):
+    # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
+    return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))
+
+
+def get_num_flop_per_token(
+    num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
+) -> int:
+    return 6 * num_non_embed_params + attention_flops_per_token(
+        n_layers, seq_len, dim, True
+    )
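+
+# Rough arithmetic sketch (hypothetical GPT-2-small-like shapes, added for clarity):
+# with n_layers=12, dim=768, seq_len=2048 and ~85M non-embedding parameters, the
+# dense term is 6 * 85e6 ~= 5.1e8 and the causal attention term is
+# 3.5 * 4 * 12 * 2048 * 768 / 2 ~= 1.3e8, i.e. ~6.4e8 FLOPs per token.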
+
+
+def causal_mask(b, h, q_idx, kv_idx):
+    return q_idx >= kv_idx
+
+
+class LMTransformerArgs(BaseTransformerArgs):
+    seed: int = 42
+
+    vocab_size: int = -1
+    weight_tying: bool = False
+
+    sliding_window: int | None = None
+
+
+class LMTransformer(BaseTransformer):
+    def __init__(self, args: LMTransformerArgs):
+        super().__init__(args)
+        self.weight_tying = args.weight_tying
+        self.sliding_window = args.sliding_window
+
+        assert args.vocab_size > 0
+
+        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
+
+        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+        self.output = nn.Linear(
+            args.dim,
+            args.vocab_size,
+            bias=False,
+        )
+
+        if args.weight_tying:
+            self.output.weight = self.tok_embeddings.weight
+
+    def forward(
+        self,
+        token_values: torch.Tensor,
+        target: Optional[torch.Tensor] = None,
+        tok_idx: Optional[torch.Tensor] = None,
+        mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
+        attn_impl: str = "sdpa",
+    ):
+        bsz, seqlen = token_values.shape
+
+        h = self.tok_embeddings(token_values)
+
+        mask = (
+            mask
+            if mask is not None
+            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
+        )
+        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
+
+        logits = self.output(self.norm(h))
+        if target is not None:
+            return cross_entropy(logits, target)
+        else:
+            return logits
+
+    def reset_parameters(self, init_std=None):
+        # Either use a fixed base std or 1/sqrt(model dim)
+        super().reset_parameters()
+        init_std = init_std or (self.dim ** (-0.5))
+        self.norm.reset_parameters()
+        nn.init.trunc_normal_(
+            self.tok_embeddings.weight,
+            mean=0.0,
+            std=init_std,
+            a=-3 * init_std,
+            b=3 * init_std,
+        )
+        if not self.weight_tying:
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
+
+
+# Optional policy for activation checkpointing. With None, we stick to the default (defined in distributed.py: default_no_recompute_ops)
+def get_no_recompute_ops():
+    return None
+
+
+# Optional and only used when the fully-sharded option (FSDP) is chosen. Highly recommended for large models.
+def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
+    group_plan: list[Tuple[str, bool]] = []
+
+    # Group the token embeddings and the output layer separately
+    group_plan.append(("tok_embeddings", False))
+
+    # Grouping by layers
+    for i in range(model_args.n_layers):
+        group_plan.append((f"layers.{i}", False))
+
+    group_plan.append(("output", True))
+
+    return group_plan
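+
+# Example (added for clarity): for a hypothetical 2-layer model the plan is
+# [("tok_embeddings", False), ("layers.0", False), ("layers.1", False), ("output", True)].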
+
+
+# Optional and only used for model/tensor parallelism when tp_size > 1
+def tp_parallelize(model, tp_mesh, model_args: LMTransformerArgs, distributed_args):
+    assert model_args.dim % distributed_args.tp_size == 0
+    assert model_args.vocab_size % distributed_args.tp_size == 0
+    assert model_args.n_heads % distributed_args.tp_size == 0
+    assert (model_args.n_kv_heads or 0) % distributed_args.tp_size == 0
+    assert model_args.n_heads % (model_args.n_kv_heads or 1) == 0
+
+    # Embedding layer tp
+    main_plan = {}
+    main_plan["tok_embeddings"] = ColwiseParallel(
+        input_layouts=Replicate(), output_layouts=Shard(1)
+    )
+    main_plan["norm"] = SequenceParallel()
+    main_plan["output"] = ColwiseParallel(
+        input_layouts=Shard(1), output_layouts=Replicate()
+    )
+
+    parallelize_module(
+        model,
+        tp_mesh,
+        main_plan,
+    )
+
+    # Attention layers tp
+    for layer in model.layers:
+        layer_plan = {}
+
+        layer_plan["attention"] = PrepareModuleInput(
+            input_layouts=(Shard(1), None),
+            desired_input_layouts=(Replicate(), None),
+        )
+        layer_plan["attention_norm"] = SequenceParallel()
+        layer_plan["attention.wq"] = ColwiseParallel()
+        layer_plan["attention.wk"] = ColwiseParallel()
+        layer_plan["attention.wv"] = ColwiseParallel()
+        layer_plan["attention.wo"] = RowwiseParallel(output_layouts=Shard(1))
+
+        # Feedforward layers tp
+        layer_plan["feed_forward"] = PrepareModuleInput(
+            input_layouts=(Shard(1),),
+            desired_input_layouts=(Replicate(),),
+        )
+        layer_plan["ffn_norm"] = SequenceParallel()
+        layer_plan["feed_forward.w1"] = ColwiseParallel()
+        layer_plan["feed_forward.w3"] = ColwiseParallel()
+        layer_plan["feed_forward.w2"] = RowwiseParallel(output_layouts=Shard(1))
+
+        parallelize_module(
+            layer,
+            tp_mesh,
+            layer_plan,
+        )
+
+        # Adjusting the number of heads and kv heads according to the tp size
+        attn_layer = layer.attention
+        attn_layer.n_heads = attn_layer.n_heads // distributed_args.tp_size
+        attn_layer.n_kv_heads = attn_layer.n_kv_heads // distributed_args.tp_size
diff --git a/dev/lint.sh b/dev/lint.sh
new file mode 100644
index 0000000..cde5ce1
--- /dev/null
+++ b/dev/lint.sh
@@ -0,0 +1,3 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+isort .
+black .
diff --git a/fixtures/tokenizer_data.json b/fixtures/tokenizer_data.json
new file mode 100644
index 0000000..1aeb9e2
--- /dev/null
+++ b/fixtures/tokenizer_data.json
@@ -0,0 +1 @@
+{"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 80, 105, 120, 43, 119, 36, 103, 108, 105, 103, 111, 36, 109, 106, 36, 120, 108, 105, 119, 105, 36, 120, 115, 111, 105, 114, 109, 126, 105, 118, 119, 36, 113, 101, 120, 103, 108, 37, 2]]}
\ No newline at end of file
diff --git a/fixtures/tokenizer_data_bpe_delim.json b/fixtures/tokenizer_data_bpe_delim.json
new file mode 100644
index 0000000..c726f6e
--- /dev/null
+++ b/fixtures/tokenizer_data_bpe_delim.json
@@ -0,0 +1 @@
+{"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 3, 80, 105, 120, 3, 43, 3, 119, 3, 36, 103, 108, 105, 103, 111, 3, 36, 109, 106, 3, 36, 120, 108, 105, 119, 105, 3, 36, 120, 115, 111, 105, 114, 3, 109, 126, 105, 118, 119, 3, 36, 113, 101, 120, 103, 108, 3, 37, 2]]}
\ No newline at end of file
diff --git a/fixtures/tokens_with_entropies.json b/fixtures/tokens_with_entropies.json
new file mode 100644
index 0000000..e3589d8
--- /dev/null
+++ b/fixtures/tokens_with_entropies.json
@@ -0,0 +1 @@
+{"position":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":6,"7":7,"8":8,"9":9,"10":10,"11":11,"12":12,"13":13,"14":14,"15":15,"16":16,"17":17,"18":18,"19":19,"20":20,"21":21,"22":22,"23":23,"24":24,"25":25,"26":26,"27":27,"28":28,"29":29,"30":30,"31":31,"32":32,"33":33,"34":34,"35":35,"36":36,"37":37,"38":38,"39":39,"40":40,"41":41,"42":42,"43":43,"44":44,"45":45,"46":46,"47":47,"48":48,"49":49,"50":50,"51":51,"52":52,"53":53,"54":54,"55":55,"56":56,"57":57,"58":58,"59":59,"60":60,"61":61,"62":62,"63":63,"64":64,"65":65,"66":66,"67":67,"68":68,"69":69,"70":70,"71":71,"72":72,"73":73,"74":74,"75":75,"76":76,"77":77,"78":78,"79":79,"80":80},"tokens":{"0":"<","1":"D","2":"a","3":"e","4":"n","5":"e","6":"r","7":"y","8":"s","9":"_","10":"T","11":"a","12":"r","13":"g","14":"a","15":"r","16":"y","17":"e","18":"n","19":"_","20":"i","21":"s","22":"_","23":"i","24":"n","25":"_","26":"G","27":"a","28":"m","29":"e","30":"_","31":"o","32":"f","33":"_","34":"T","35":"h","36":"r","37":"o","38":"n","39":"e","40":"s","41":",","42":"_","43":"a","44":"_","45":"f","46":"a","47":"n","48":"t","49":"a","50":"s","51":"y","52":"_","53":"e","54":"p","55":"i","56":"c","57":"_","58":"b","59":"y","60":"_","61":"G","62":"e","63":"o","64":"r","65":"g","66":"e","67":"_","68":"R","69":".","70":"R","71":".","72":"_","73":"M","74":"a","75":"r","76":"t","77":"i","78":"n","79":".","80":">"},"token_ids":{"0":1,"1":72,"2":101,"3":105,"4":114,"5":105,"6":118,"7":125,"8":119,"9":36,"10":88,"11":101,"12":118,"13":107,"14":101,"15":118,"16":125,"17":105,"18":114,"19":36,"20":109,"21":119,"22":36,"23":109,"24":114,"25":36,"26":75,"27":101,"28":113,"29":105,"30":36,"31":115,"32":106,"33":36,"34":88,"35":108,"36":118,"37":115,"38":114,"39":105,"40":119,"41":48,"42":36,"43":101,"44":36,"45":106,"46":101,"47":114,"48":120,"49":101,"50":119,"51":125,"52":36,"53":105,"54":116,"55":109,"56":103,"57":36,"58":102,"59":125,"60":36,"61":75,"62":105,"63":115,"64":118,"65":107,"66":105,"67":36,"68":86,"69":50,"70":86,"71":50,"72":36,"73":81,"74":101,"75":118,"76":120,"77":109,"78":114,"79":50,"80":2},"entropies":{"0":3.3949158192,"1":2.1656451225,"2":2.3216569424,"3":2.8214058876,"4":1.5249242783,"5":0.0401624143,"6":0.0981037766,"7":0.0544578359,"8":0.3430138826,"9":1.0546212196,"10":0.25252828,"11":0.1494535804,"12":0.0624754503,"13":0.001355894,"14":0.0050173439,"15":0.0052358187,"16":0.0011725067,"17":0.0010307421,"18":1.0241208076,"19":3.6867966652,"20":0.4502205253,"21":0.0484119244,"22":2.2572875023,"23":0.3789347112,"24":1.0042934418,"25":2.9090054035,"26":1.8933598995,"27":1.3859074116,"28":0.3827198744,"29":0.2646365762,"30":1.7742085457,"31":0.0136727821,"32":0.0053820172,"33":0.5485631227,"34":0.2064044327,"35":0.0049266233,"36":0.0005439016,"37":0.0007023578,"38":0.0004170335,"39":0.0054524317,"40":1.1938130856,"41":0.0238215197,"42":3.1279797554,"43":1.3883389235,"44":3.0503094196,"45":1.695879817,"46":1.8551058769,"47":1.4570231438,"48":0.0047810897,"49":0.026396824,"50":0.6633765101,"51":0.3141393065,"52":2.8411159515,"53":1.143143177,"54":0.0520330966,"55":0.3398066461,"56":0.4140175879,"57":2.5563707352,"58":1.3370712996,"59":0.0227173548,"60":3.4447185993,"61":1.8576486111,"62":0.8189754486,"63":0.6776530743,"64":0.0677763447,"65":0.212713033,"66":0.1003480032,"67":0.1746164262,"68":0.4123829603,"69":0.5507118702,"70":0.1047425047,"71":0.0194335245,"72":0.001482119,"73":0.0009310447,"74":0.0002176317,"75":0.0076908777,"76":0.0003866984,"77":0.0008008487,"78":1.2395234108,"79":0.4564163089,"80":0.0000461392},"patch":{"0":0,"1":1,"2"
:2,"3":3,"4":4,"5":5,"6":5,"7":5,"8":5,"9":5,"10":5,"11":5,"12":5,"13":5,"14":5,"15":5,"16":5,"17":5,"18":5,"19":5,"20":6,"21":6,"22":6,"23":7,"24":7,"25":7,"26":8,"27":9,"28":10,"29":10,"30":10,"31":11,"32":11,"33":11,"34":11,"35":11,"36":11,"37":11,"38":11,"39":11,"40":11,"41":11,"42":11,"43":12,"44":13,"45":14,"46":15,"47":16,"48":17,"49":17,"50":17,"51":17,"52":17,"53":18,"54":18,"55":18,"56":18,"57":18,"58":19,"59":20,"60":20,"61":21,"62":22,"63":22,"64":22,"65":22,"66":22,"67":22,"68":22,"69":22,"70":22,"71":22,"72":22,"73":22,"74":22,"75":22,"76":22,"77":22,"78":22,"79":22,"80":22},"start":{"0":1,"1":1,"2":1,"3":1,"4":1,"5":1,"6":0,"7":0,"8":0,"9":0,"10":0,"11":0,"12":0,"13":0,"14":0,"15":0,"16":0,"17":0,"18":0,"19":0,"20":1,"21":0,"22":0,"23":1,"24":0,"25":0,"26":1,"27":1,"28":1,"29":0,"30":0,"31":1,"32":0,"33":0,"34":0,"35":0,"36":0,"37":0,"38":0,"39":0,"40":0,"41":0,"42":0,"43":1,"44":1,"45":1,"46":1,"47":1,"48":1,"49":0,"50":0,"51":0,"52":0,"53":1,"54":0,"55":0,"56":0,"57":0,"58":1,"59":1,"60":0,"61":1,"62":1,"63":0,"64":0,"65":0,"66":0,"67":0,"68":0,"69":0,"70":0,"71":0,"72":0,"73":0,"74":0,"75":0,"76":0,"77":0,"78":0,"79":0,"80":0}}
\ No newline at end of file
diff --git a/plot_data/entropy_figure.json b/plot_data/entropy_figure.json
new file mode 100644
index 0000000..79e42d9
--- /dev/null
+++ b/plot_data/entropy_figure.json
@@ -0,0 +1 @@
+{"text":"Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.","threshold":1.335442066192627,"dataframe_json":"{\"position\":{\"0\":0,\"1\":1,\"2\":2,\"3\":3,\"4\":4,\"5\":5,\"6\":6,\"7\":7,\"8\":8,\"9\":9,\"10\":10,\"11\":11,\"12\":12,\"13\":13,\"14\":14,\"15\":15,\"16\":16,\"17\":17,\"18\":18,\"19\":19,\"20\":20,\"21\":21,\"22\":22,\"23\":23,\"24\":24,\"25\":25,\"26\":26,\"27\":27,\"28\":28,\"29\":29,\"30\":30,\"31\":31,\"32\":32,\"33\":33,\"34\":34,\"35\":35,\"36\":36,\"37\":37,\"38\":38,\"39\":39,\"40\":40,\"41\":41,\"42\":42,\"43\":43,\"44\":44,\"45\":45,\"46\":46,\"47\":47,\"48\":48,\"49\":49,\"50\":50,\"51\":51,\"52\":52,\"53\":53,\"54\":54,\"55\":55,\"56\":56,\"57\":57,\"58\":58,\"59\":59,\"60\":60,\"61\":61,\"62\":62,\"63\":63,\"64\":64,\"65\":65,\"66\":66,\"67\":67,\"68\":68,\"69\":69,\"70\":70,\"71\":71,\"72\":72,\"73\":73,\"74\":74,\"75\":75,\"76\":76,\"77\":77,\"78\":78,\"79\":79,\"80\":80},\"tokens\":{\"0\":\"<\",\"1\":\"D\",\"2\":\"a\",\"3\":\"e\",\"4\":\"n\",\"5\":\"e\",\"6\":\"r\",\"7\":\"y\",\"8\":\"s\",\"9\":\"_\",\"10\":\"T\",\"11\":\"a\",\"12\":\"r\",\"13\":\"g\",\"14\":\"a\",\"15\":\"r\",\"16\":\"y\",\"17\":\"e\",\"18\":\"n\",\"19\":\"_\",\"20\":\"i\",\"21\":\"s\",\"22\":\"_\",\"23\":\"i\",\"24\":\"n\",\"25\":\"_\",\"26\":\"G\",\"27\":\"a\",\"28\":\"m\",\"29\":\"e\",\"30\":\"_\",\"31\":\"o\",\"32\":\"f\",\"33\":\"_\",\"34\":\"T\",\"35\":\"h\",\"36\":\"r\",\"37\":\"o\",\"38\":\"n\",\"39\":\"e\",\"40\":\"s\",\"41\":\",\",\"42\":\"_\",\"43\":\"a\",\"44\":\"_\",\"45\":\"f\",\"46\":\"a\",\"47\":\"n\",\"48\":\"t\",\"49\":\"a\",\"50\":\"s\",\"51\":\"y\",\"52\":\"_\",\"53\":\"e\",\"54\":\"p\",\"55\":\"i\",\"56\":\"c\",\"57\":\"_\",\"58\":\"b\",\"59\":\"y\",\"60\":\"_\",\"61\":\"G\",\"62\":\"e\",\"63\":\"o\",\"64\":\"r\",\"65\":\"g\",\"66\":\"e\",\"67\":\"_\",\"68\":\"R\",\"69\":\".\",\"70\":\"R\",\"71\":\".\",\"72\":\"_\",\"73\":\"M\",\"74\":\"a\",\"75\":\"r\",\"76\":\"t\",\"77\":\"i\",\"78\":\"n\",\"79\":\".\",\"80\":\">\"},\"token_ids\":{\"0\":1,\"1\":72,\"2\":101,\"3\":105,\"4\":114,\"5\":105,\"6\":118,\"7\":125,\"8\":119,\"9\":36,\"10\":88,\"11\":101,\"12\":118,\"13\":107,\"14\":101,\"15\":118,\"16\":125,\"17\":105,\"18\":114,\"19\":36,\"20\":109,\"21\":119,\"22\":36,\"23\":109,\"24\":114,\"25\":36,\"26\":75,\"27\":101,\"28\":113,\"29\":105,\"30\":36,\"31\":115,\"32\":106,\"33\":36,\"34\":88,\"35\":108,\"36\":118,\"37\":115,\"38\":114,\"39\":105,\"40\":119,\"41\":48,\"42\":36,\"43\":101,\"44\":36,\"45\":106,\"46\":101,\"47\":114,\"48\":120,\"49\":101,\"50\":119,\"51\":125,\"52\":36,\"53\":105,\"54\":116,\"55\":109,\"56\":103,\"57\":36,\"58\":102,\"59\":125,\"60\":36,\"61\":75,\"62\":105,\"63\":115,\"64\":118,\"65\":107,\"66\":105,\"67\":36,\"68\":86,\"69\":50,\"70\":86,\"71\":50,\"72\":36,\"73\":81,\"74\":101,\"75\":118,\"76\":120,\"77\":109,\"78\":114,\"79\":50,\"80\":2},\"entropies\":{\"0\":3.3949158192,\"1\":2.1656451225,\"2\":2.3216569424,\"3\":2.8214058876,\"4\":1.5249242783,\"5\":0.0401624143,\"6\":0.0981037766,\"7\":0.0544578359,\"8\":0.3430138826,\"9\":1.0546212196,\"10\":0.25252828,\"11\":0.1494535804,\"12\":0.0624754503,\"13\":0.001355894,\"14\":0.0050173439,\"15\":0.0052358187,\"16\":0.0011725067,\"17\":0.0010307421,\"18\":1.0241208076,\"19\":3.6867966652,\"20\":0.4502205253,\"21\":0.0484119244,\"22\":2.2572875023,\"23\":0.3789347112,\"24\":1.0042934418,\"25\":2.9090054035,\"26\":1.8933598995,\"27\":1.3859074116,\"28\":0.3827198744,\"29\":0.2646365762,\"30\":1.7742085457,\"31\":0.0136727821,\"32\":0.0053820172,\"33\":0.5485631227,\"34
\":0.2064044327,\"35\":0.0049266233,\"36\":0.0005439016,\"37\":0.0007023578,\"38\":0.0004170335,\"39\":0.0054524317,\"40\":1.1938130856,\"41\":0.0238215197,\"42\":3.1279797554,\"43\":1.3883389235,\"44\":3.0503094196,\"45\":1.695879817,\"46\":1.8551058769,\"47\":1.4570231438,\"48\":0.0047810897,\"49\":0.026396824,\"50\":0.6633765101,\"51\":0.3141393065,\"52\":2.8411159515,\"53\":1.143143177,\"54\":0.0520330966,\"55\":0.3398066461,\"56\":0.4140175879,\"57\":2.5563707352,\"58\":1.3370712996,\"59\":0.0227173548,\"60\":3.4447185993,\"61\":1.8576486111,\"62\":0.8189754486,\"63\":0.6776530743,\"64\":0.0677763447,\"65\":0.212713033,\"66\":0.1003480032,\"67\":0.1746164262,\"68\":0.4123829603,\"69\":0.5507118702,\"70\":0.1047425047,\"71\":0.0194335245,\"72\":0.001482119,\"73\":0.0009310447,\"74\":0.0002176317,\"75\":0.0076908777,\"76\":0.0003866984,\"77\":0.0008008487,\"78\":1.2395234108,\"79\":0.4564163089,\"80\":0.0000461392},\"patch\":{\"0\":0,\"1\":1,\"2\":2,\"3\":3,\"4\":4,\"5\":5,\"6\":5,\"7\":5,\"8\":5,\"9\":5,\"10\":5,\"11\":5,\"12\":5,\"13\":5,\"14\":5,\"15\":5,\"16\":5,\"17\":5,\"18\":5,\"19\":5,\"20\":6,\"21\":6,\"22\":6,\"23\":7,\"24\":7,\"25\":7,\"26\":8,\"27\":9,\"28\":10,\"29\":10,\"30\":10,\"31\":11,\"32\":11,\"33\":11,\"34\":11,\"35\":11,\"36\":11,\"37\":11,\"38\":11,\"39\":11,\"40\":11,\"41\":11,\"42\":11,\"43\":12,\"44\":13,\"45\":14,\"46\":15,\"47\":16,\"48\":17,\"49\":17,\"50\":17,\"51\":17,\"52\":17,\"53\":18,\"54\":18,\"55\":18,\"56\":18,\"57\":18,\"58\":19,\"59\":20,\"60\":20,\"61\":21,\"62\":22,\"63\":22,\"64\":22,\"65\":22,\"66\":22,\"67\":22,\"68\":22,\"69\":22,\"70\":22,\"71\":22,\"72\":22,\"73\":22,\"74\":22,\"75\":22,\"76\":22,\"77\":22,\"78\":22,\"79\":22,\"80\":22},\"start\":{\"0\":1,\"1\":1,\"2\":1,\"3\":1,\"4\":1,\"5\":1,\"6\":0,\"7\":0,\"8\":0,\"9\":0,\"10\":0,\"11\":0,\"12\":0,\"13\":0,\"14\":0,\"15\":0,\"16\":0,\"17\":0,\"18\":0,\"19\":0,\"20\":1,\"21\":0,\"22\":0,\"23\":1,\"24\":0,\"25\":0,\"26\":1,\"27\":1,\"28\":1,\"29\":0,\"30\":0,\"31\":1,\"32\":0,\"33\":0,\"34\":0,\"35\":0,\"36\":0,\"37\":0,\"38\":0,\"39\":0,\"40\":0,\"41\":0,\"42\":0,\"43\":1,\"44\":1,\"45\":1,\"46\":1,\"47\":1,\"48\":1,\"49\":0,\"50\":0,\"51\":0,\"52\":0,\"53\":1,\"54\":0,\"55\":0,\"56\":0,\"57\":0,\"58\":1,\"59\":1,\"60\":0,\"61\":1,\"62\":1,\"63\":0,\"64\":0,\"65\":0,\"66\":0,\"67\":0,\"68\":0,\"69\":0,\"70\":0,\"71\":0,\"72\":0,\"73\":0,\"74\":0,\"75\":0,\"76\":0,\"77\":0,\"78\":0,\"79\":0,\"80\":0}}"}
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..e2ecd0d
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,5 @@
+[tool.isort]
+profile = "black"
+known_bytelatent = "bytelatent"
+known_apps = "apps"
+sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..c6d87f1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+numpy
+omegaconf
+msgspec
+rouge-score
+sacrebleu
+sentencepiece
+tiktoken
+fsspec
+blobfile
+wandb
+viztracer
+lm-eval
+scipy
+pynvml
+datatrove
+orjson
+luigi
+pydantic
+altair
+submitit
+typer
+rich
diff --git a/setup/create_env.sh b/setup/create_env.sh
new file mode 100644
index 0000000..9c7dad9
--- /dev/null
+++ b/setup/create_env.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+#SBATCH --job-name=env_creation
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --gres=gpu:8
+#SBATCH --exclusive
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=128
+#SBATCH --mem=0
+#SBATCH --time=01:00:00
+
+# Exit immediately if a command exits with a non-zero status
+set -e
+
+# Start timer
+start_time=$(date +%s)
+
+# Get the current date
+current_date=$(date +%y%m%d)
+
+# Create environment name with the current date
+env_prefix=blt_$current_date
+
+# Create the conda environment
+
+source $CONDA_ROOT/etc/profile.d/conda.sh
+conda create -n $env_prefix python=3.11 -y -c anaconda
+conda activate $env_prefix
+
+echo "Currently in env $(which python)"
+
+# Install packages
+pip install torch==2.5.0 xformers --index-url https://download.pytorch.org/whl/cu121
+pip install ninja
+pip install --requirement requirements.txt
+
+# End timer
+end_time=$(date +%s)
+
+# Calculate elapsed time in seconds
+elapsed_time=$((end_time - start_time))
+
+# Convert elapsed time to minutes
+elapsed_minutes=$((elapsed_time / 60))
+
+echo "Environment $env_prefix created and all packages installed successfully in $elapsed_minutes minutes!"
diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py
new file mode 100644
index 0000000..1aacf8f
--- /dev/null
+++ b/setup/download_prepare_hf_data.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import argparse
+import os
+import subprocess
+import time
+
+import requests
+from huggingface_hub import snapshot_download
+
+
+def run_command(command):
+    print(f"Running: {command}")
+    subprocess.run(command, shell=True, check=True)
+
+
+def download_dataset(repo_id, local_dir, allow_patterns):
+    print(f"Downloading dataset from {repo_id}...")
+    max_retries = 5
+    retry_delay = 10  # seconds
+    for attempt in range(max_retries):
+        try:
+            snapshot_download(
+                repo_id,
+                repo_type="dataset",
+                local_dir=local_dir,
+                allow_patterns=allow_patterns,
+                resume_download=True,
+                max_workers=16,  # Don't hesitate to increase this number to lower the download time
+            )
+            break
+        except requests.exceptions.ReadTimeout:
+            if attempt < max_retries - 1:
+                print(f"Timeout occurred. Retrying in {retry_delay} seconds...")
+                time.sleep(retry_delay)
+            else:
+                raise
+    print(f"Dataset downloaded to {local_dir}")
+
+
+def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
+    from datatrove.executor import LocalPipelineExecutor
+    from datatrove.pipeline.readers import ParquetReader
+    from datatrove.pipeline.writers import JsonlWriter
+
+    pipeline_exec = LocalPipelineExecutor(
+        pipeline=[
+            ParquetReader(
+                src_dir,
+                file_progress=True,
+                doc_progress=True,
+                glob_pattern="**/*.parquet",
+            ),
+            JsonlWriter(
+                tgt_dir,
+                output_filename=dataset + ".chunk.${rank}.jsonl",
+                compression=None,
+            ),
+        ],
+        tasks=ntasks,
+        logging_dir=os.path.join(work_dir, "datatrove"),
+    )
+    pipeline_exec.run()
+
+
+def setup_terashuf(work_dir):
+    terashuf_dir = os.path.join(work_dir, "terashuf")
+    terashuf_executable = os.path.join(terashuf_dir, "terashuf")
+
+    if os.path.exists(terashuf_executable):
+        print("terashuf executable already exists. Skipping setup.")
+        return terashuf_dir
+
+    print("Setting up terashuf...")
+    run_command(f"git clone https://github.com/alexandres/terashuf {terashuf_dir}")
+    run_command(f"make -C {terashuf_dir}")
+    return terashuf_dir
+
+
+def main(dataset, memory, data_dir, seed=42, nchunks=32):
+    # Configuration
+    repo_id = {
+        "fineweb_edu": "HuggingFaceFW/fineweb-edu",
+        "fineweb_edu_10bt": "HuggingFaceFW/fineweb-edu",
+        "dclm_baseline_1.0": "mlfoundations/dclm-baseline-1.0",
+        "dclm_baseline_1.0_10prct": "mlfoundations/dclm-baseline-1.0",
+    }[dataset]
+    src_dir = f"{data_dir}/{dataset}"
+    out_dir = f"{src_dir}_shuffled"
+    os.makedirs(out_dir, exist_ok=True)
+    work_dir = src_dir  # Use the dataset directory as the working directory
+    prefix = f"{dataset}.chunk."
+    orig_extension = {
+        "fineweb_edu": ".jsonl",
+        "fineweb_edu_10bt": ".jsonl",
+        "dclm_baseline_1.0": ".jsonl.zst",
+        "dclm_baseline_1.0_10prct": ".jsonl.zst",
+    }[dataset]
+    cat_command = {
+        "fineweb_edu": "cat",
+        "fineweb_edu_10bt": "cat",
+        "dclm_baseline_1.0": "zstdcat",
+        "dclm_baseline_1.0_10prct": "zstdcat",
+    }[dataset]
+    allow_patterns = {
+        "fineweb_edu": None,
+        "fineweb_edu_10bt": "sample/10BT/*",
+        "dclm_baseline_1.0": "*.jsonl.zst",
+        "dclm_baseline_1.0_10prct": "global-shard_01_of_10/*.jsonl.zst",
+    }[dataset]
+    suffix = ".jsonl"
+    k_validation = 10000  # Number of lines to take from each chunk for validation
+
+    # Setup terashuf
+    terashuf_dir = setup_terashuf(work_dir)
+
+    # Download dataset
+    download_dataset(repo_id, src_dir, allow_patterns)
+
+    if "fineweb" in dataset:
+        parquet_to_jsonl(dataset, work_dir, src_dir, src_dir)
+
+    # Set up environment variables
+    os.environ["MEMORY"] = f"{memory}"
+    os.environ["SEED"] = f"{seed}"
+
+    # Run the original shuffling and splitting command
+    terashuf_executable = os.path.join(terashuf_dir, "terashuf")
+    run_command(
+        f"ulimit -n 100000 && "
+        f"find {src_dir} -type f -name '*{orig_extension}' -print0 | xargs -0 {cat_command} | {terashuf_executable} | "
+        f"split -n r/{nchunks} -d --suffix-length 2 --additional-suffix {suffix} - {out_dir}/{prefix}"
+        "; trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' SIGPIPE;"
+    )
+
+    # Create validation set and remove lines from chunks
+    validation_file = f"{out_dir}/{dataset}.val{suffix}"
+    for i in range(nchunks):
+        chunk_file = f"{out_dir}/{prefix}{i:02d}{suffix}"
+        run_command(f"head -n {k_validation} {chunk_file} >> {validation_file}")
+        run_command(f"sed -i '1,{k_validation}d' {chunk_file}")
+
+    print("All tasks completed successfully!")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("dataset", type=str)
+    parser.add_argument("memory", type=float, default=8)
+    parser.add_argument("--data_dir", type=str, default="data")
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--nchunks", type=int, default=32)
+
+    args = parser.parse_args()
+
+    main(args.dataset, args.memory, args.data_dir, args.seed, args.nchunks)
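+
+    # Illustrative usage (assumes the memory argument is the terashuf budget in GB):
+    #   python setup/download_prepare_hf_data.py fineweb_edu_10bt 8 --data_dir ./data --seed 42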
diff --git a/setup/download_tokenizer.py b/setup/download_tokenizer.py
new file mode 100644
index 0000000..fc9c8d5
--- /dev/null
+++ b/setup/download_tokenizer.py
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import argparse
+import os
+from typing import Optional
+
+from requests.exceptions import HTTPError
+
+TOKENIZER = {
+    "llama2": ("meta-llama/Llama-2-7b", "tokenizer.model"),
+    "llama3": ("meta-llama/Meta-Llama-3-8B", "original/tokenizer.model"),
+    "gemma": ("google/gemma-2-9b", "tokenizer.model"),
+}
+
+
+def main(tokenizer_name: str, path_to_save: str, api_key: Optional[str] = None):
+    if tokenizer_name in TOKENIZER:
+        repo_id, filename = TOKENIZER[tokenizer_name]
+
+        from huggingface_hub import hf_hub_download
+
+        try:
+            hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                local_dir=path_to_save,
+                local_dir_use_symlinks=False,
+                token=api_key if api_key else None,
+            )
+        except HTTPError as e:
+            if e.response.status_code == 401:
+                print(
+                    "You need to pass a valid `--hf_token=...` to download private checkpoints."
+                )
+            else:
+                raise e
+    else:
+        from tiktoken import get_encoding
+
+        if "TIKTOKEN_CACHE_DIR" not in os.environ:
+            os.environ["TIKTOKEN_CACHE_DIR"] = path_to_save
+        try:
+            get_encoding(tokenizer_name)
+        except ValueError:
+            print(
+                f"Tokenizer {tokenizer_name} not found. Please check the name and try again."
+            )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("tokenizer_name", type=str)
+    parser.add_argument("tokenizer_dir", type=str, default=8)
+    parser.add_argument("--api_key", type=str, default="")
+    args = parser.parse_args()
+
+    main(
+        tokenizer_name=args.tokenizer_name,
+        path_to_save=args.tokenizer_dir,
+        api_key=args.api_key,
+    )
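+
+    # Illustrative usage (the token value is a placeholder):
+    #   python setup/download_tokenizer.py llama3 tokenizers/ --api_key=hf_...
+    #   python setup/download_tokenizer.py cl100k_base tokenizers/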