commit bcc039bb75aec70385fd5c148497d4ae86b526b5 Author: Pedro Rodriguez Date: Thu Dec 12 15:32:30 2024 -0800 Initial commit diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..433932f --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,12 @@ +name: Lint with Black + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable + with: + version: "24.8.0" diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml new file mode 100644 index 0000000..16f1abf --- /dev/null +++ b/.github/workflows/isort.yml @@ -0,0 +1,10 @@ +name: Lint with isort + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: isort/isort-action@master diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..56891a9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,168 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ +*.out + +figures/ +.vscode/ +.DS_Store + diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..ae6f541 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,8 @@ +{ + "overrides": [ + { + "files": "*.yaml", + "options": { "tabWidth": 2 } + } + ] +} diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..3232ed6 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. 
+Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..88d3ab0 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,36 @@ +# Contributing to + +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests + +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") + +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues + +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License + +By contributing to BLT, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..bc559a9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright 2024 Meta + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice,this list +of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this +list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. 
Neither the name of the copyright holder nor the names of its contributors may
+be used to endorse or promote products derived from this software without specific
+prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
+EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
+OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
+SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
+BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+DAMAGE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..689c8e4
--- /dev/null
+++ b/README.md
@@ -0,0 +1,117 @@
+# Byte Latent Transformer
+
+This repository contains code for our paper: "Byte Latent Transformer: Patches Scale Better Than Tokens"
+
+- [Paper Link](https://dl.fbaipublicfiles.com/blt/BLT__Patches_Scale_Better_Than_Tokens.pdf)
+
+## Abstract
+
+We introduce the Byte Latent Transformer (BLT), a new byte-level LLM architecture that,
+for the first time, matches tokenization-based LLM performance at scale, with significant improvements
+in inference efficiency and robustness. BLT encodes bytes into dynamically sized patches, which serve
+as the primary units of computation. Patches are segmented dynamically based on the entropy of the
+next byte, allocating more compute and model capacity where there is more data complexity. The BLT
+architecture includes new attention mechanisms to maximize the information flow between byte and
+patch hidden representations and a new type of byte-sequence memory. We present the first scaling
+study of byte-level models up to 8B parameters and 8T training bytes, showing for the first time
+that we can train a model end-to-end at scale from bytes with no tokenization or other preprocessing.
+Scaling trends reveal training and inference efficiency benefits from dynamically selecting very long
+patches on average, along with qualitative improvements in reasoning and long-tail generalization
+from modeling byte-sequences.
+
+![BLT Architecture Diagram](blt-figure.jpg)
+
+## Development Status
+
+We are actively updating the BLT code to make it easier to reproduce our results.
+Please file an issue and/or be patient while we make more of our code public!
+
+## Quick start
+
+The following commands launch a SLURM job that creates an environment for Meta Lingua.
+The environment creation should take around 5 minutes, not counting downloads.
+
+```bash
+git clone https://github.com/facebookresearch/blt
+cd blt
+
+bash setup/create_env.sh
+# or if you have access to a SLURM cluster
+sbatch setup/create_env.sh
+```
+
+Once that is done, you can activate the environment:
+
+```bash
+conda activate blt_
+```
+
+Use the provided script to download and prepare data from Hugging Face (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
+This command will download the `fineweb_edu` dataset and prepare it for training in the `./data` directory, and lets you specify the amount of memory allocated to `terashuf` (the tool used to shuffle samples). By default, the number of chunks (`nchunks`) is 32.
If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` with the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.
+
+```bash
+python setup/download_prepare_hf_data.py fineweb_edu --data_dir ./data --seed 42 --nchunks
+```
+
+To download the tokenizer (here llama3), use the following script:
+
+```bash
+python setup/download_tokenizer.py llama3 --api_key
+```
+
+Now launch a debug job to check if everything works. **The provided configurations are templates; you need to adapt them before they will work (change `dump_dir`, `data.root_dir`, `data.tokenizer.path`, etc.)**
+
+```bash
+# stool stands for SLURM tool!
+python -m bytelatent.stool script=bytelatent.train config=apps/bytelatent/configs/debug.yaml nodes=1 partition=
+# if you want to launch locally you can use torchrun
+torchrun --nproc-per-node 8 -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
+# or you can also launch on 1 GPU
+python -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
+```
+
+When using `stool`, if a job crashes, it can be relaunched using sbatch:
+
+```bash
+sbatch path/to/dump_dir/submit.slurm
+```
+
+## Linting
+
+To lint, run the following command:
+
+```
+bash dev/lint.sh
+```
+
+## Citation
+
+The BLT codebase is partially based on Meta Lingua, so consider citing it in addition to our BLT paper if you re-use our work.
+
+BLT Paper Citation (will be updated to arXiv soon)
+
+```
+@article{meta_blt,
+  author = {Artidoro Pagnoni, Ram Pasunuru, Pedro Rodriguez, John Nguyen, Benjamin Muller, Margaret Li, Chunting Zhou, Lili Yu, Jason Weston, Luke Zettlemoyer, Gargi Ghosh, Mike Lewis, Ari Holtzman†, Srinivasan Iyer},
+  title = {Byte Latent Transformer: Patches Scale Better Than Tokens},
+  url = {https://github.com/facebookresearch/blt},
+  year = {2024}
+}
+```
+
+Lingua Code
+
+```
+@misc{meta_lingua,
+  author = {Mathurin Videau, Badr Youbi Idrissi, Daniel Haziza, Luca Wehrstedt, Jade Copet, Olivier Teytaud, David Lopez-Paz},
+  title = {{Meta Lingua}: A minimal {PyTorch LLM} training library},
+  url = {https://github.com/facebookresearch/lingua},
+  year = {2024}
+}
+```
+
+## License
+
+The BLT code is partially based on Meta Lingua.
+
+Meta Lingua is licensed under the BSD-3-Clause license. Refer to the LICENSE file in the top-level directory.
diff --git a/apps/__init__.py b/apps/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/apps/main/__init__.py b/apps/main/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/apps/main/configs/eval.yaml b/apps/main/configs/eval.yaml
new file mode 100644
index 0000000..4b52ba0
--- /dev/null
+++ b/apps/main/configs/eval.yaml
@@ -0,0 +1,35 @@
+name: "debug_evals"
+# ckpt_dir: !!CHANGETHIS!!
+# dump_dir: !!CHANGETHIS!!
+generator:
+  max_tokens: 8192
+  dtype: bf16
+  temperature: 1.0
+  top_p: 0.95
+harness:
+  tasks:
+    - hellaswag
+    - task: boolq
+      dataset_kwargs:
+        trust_remote_code: true
+    - task: nq_open
+      num_fewshot: 5
+    - piqa
+    - task: social_iqa
+      dataset_kwargs:
+        trust_remote_code: true
+    - triviaqa
+    - winogrande
+    - openbookqa
+    - arc_easy
+    - arc_challenge
+    - race
+    - commonsense_qa
+    # - coqa
+    - copa
+    - gsm8k
+    - bbh
+    - mmlu
+    - mmlu_pro
+validation:
+  max_steps: 1000
diff --git a/apps/main/configs/llama_1B.yaml b/apps/main/configs/llama_1B.yaml
new file mode 100644
index 0000000..786d848
--- /dev/null
+++ b/apps/main/configs/llama_1B.yaml
@@ -0,0 +1,87 @@
+# dump_dir: !!!CHANGE_THIS!!!
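+# Example launch, mirroring the comment at the top of llama_7B.yaml below; the node
+# count, account, and qos values here are placeholders rather than tested settings:
+# python -m lingua.stool config=apps/main/configs/llama_1B.yaml nodes=8 account=<account> qos=lowest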
+name: large_lm +steps: 60_000 +probe_freq: null +seed: 777 + +optim: + lr: 3e-3 + weight_decay: 0.033 + warmup: 5000 + lr_min_ratio: 0.000001 + clip: 1.0 + +distributed: + fsdp_type: full_shard + compile: true + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +model: + dim: 2048 + n_layers: 25 + n_heads: 16 + +data: + root_dir: data/shuffled + sources: + dclm_baseline_1.0: 100.0 + batch_size: 4 + prefetch_size: 1024 + seq_len: 4096 + n_views: 2 + load_async: true + add_bos: true + add_eos: true + tokenizer: + name: tiktoken + path: tokenizers/cl_toplang_128k.tiktoken + +profiling: + run: true + mem_warmup: 0 + mem_steps: 4 + profile_warmup: 100 + profile_steps: 4 + +checkpoint: + dump: + every: 2500 + keep: 3 + eval: + every: 5000 + keep: -1 + +logging: + freq: 1 + +async_eval_gpus: 8 +eval: + harness: + tasks: + - hellaswag + - task: boolq + dataset_kwargs: + trust_remote_code: true + - piqa + - task: social_iqa + dataset_kwargs: + trust_remote_code: true + - winogrande + - openbookqa + - arc_easy + - arc_challenge + - race + - commonsense_qa + - copa + # - coqa + # - task: nq_open + # num_fewshot: 5 + # - triviaqa + validation: + max_steps: 1000 + generator: + max_tokens: 16384 + dtype: bf16 diff --git a/apps/main/configs/llama_7B.yaml b/apps/main/configs/llama_7B.yaml new file mode 100644 index 0000000..4461dd9 --- /dev/null +++ b/apps/main/configs/llama_7B.yaml @@ -0,0 +1,95 @@ +#python -m lingua.stool config=apps/main/configs/llama2_7B.yaml nodes=32 account=fair_amaia_cw_codegen qos=lowest +# dump_dir: !!!CHANGE_THIS!!! +name: "7b_baseline" +steps: 100_000 +grad_acc_steps: 1 +probe_freq: 100 + +seed: 777 +optim: + lr: 1.0e-3 + weight_decay: 0.1 + warmup: 2000 + lr_min_ratio: 0.000001 + clip: 1.0 + +distributed: + fsdp_type: full_shard + compile: true + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +model: + dim: 4096 + n_layers: 32 + n_heads: 32 + rope_theta: 100_000 + ffn_dim_multiplier: 1.0 + multiple_of: 256 + +data: + root_dir: data/shuffled + sources: + dclm_baseline_1.0: 1.0 + batch_size: 2 + prefetch_size: 1024 + seq_len: 4096 + n_views: 2 + load_async: true + tokenizer: + name: tiktoken + path: tokenizers/cl_toplang_128k.tiktoken + +profiling: + run: true + mem_warmup: 0 + mem_steps: 4 + profile_warmup: 100 + profile_steps: 4 + +checkpoint: + dump: + every: 10000 + keep: -1 + eval: + every: 1000 + keep: 3 + +logging: + freq: 1 + +async_eval_gpus: 8 +eval: + dataset_dir: datasets/eval + harness: + tasks: + - hellaswag + - task: boolq + dataset_kwargs: + trust_remote_code: true + - piqa + - task: social_iqa + dataset_kwargs: + trust_remote_code: true + - winogrande + - openbookqa + - arc_easy + - arc_challenge + - race + - commonsense_qa + # - coqa + - copa + - mmlu + - mmlu_pro + # - task: nq_open + # num_fewshot: 5 + # - triviaqa + # - gsm8k + # - bbh + validation: + max_steps: 1000 + generator: + max_tokens: 8192 + dtype: bf16 diff --git a/apps/main/eval.py b/apps/main/eval.py new file mode 100644 index 0000000..ed20f49 --- /dev/null +++ b/apps/main/eval.py @@ -0,0 +1,354 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
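+# Evaluation entry point: launch_eval() consolidates a checkpoint when needed, wraps
+# the generator in an lm-eval-harness LM (EvalHarnessLM), runs the configured harness
+# tasks, and optionally reports validation-set log-likelihoods via eval_on_val().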
+ +import json +import logging +import os +from collections import defaultdict +from dataclasses import asdict, dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import torch +from lingua.args import dump_config +from lingua.data import init_choice_state, setup_sources +from lm_eval import simple_evaluate +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from omegaconf import OmegaConf + +from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.distributed import ( + DistributedArgs, + dist_mean_dict, + get_global_rank, + get_world_size, + setup_torch_distributed, +) +from bytelatent.transformer import LMTransformer, LMTransformerArgs + +from apps.main.generate import ( + PackedCausalTransformerGenerator, + PackedCausalTransformerGeneratorArgs, + load_consolidated_model_and_tokenizer, +) + +EVAL_FOLDER_NAME = "{:010d}" + +logger = logging.getLogger() + + +@dataclass +class LMHarnessArgs: + tasks: Optional[List[Any]] = None + num_fewshot: Optional[int] = None + device: Optional[str] = None + use_cache: Optional[str] = None + cache_requests: bool = False + rewrite_requests_cache: bool = False + delete_requests_cache: bool = False + limit: Optional[Union[int, float]] = None + bootstrap_iters: int = 100000 + check_integrity: bool = False + write_out: bool = False + log_samples: bool = True + system_instruction: Optional[str] = None + apply_chat_template: Union[bool, str] = False + fewshot_as_multiturn: bool = False + gen_kwargs: Optional[str] = None + verbosity: str = "INFO" + predict_only: bool = False + random_seed: int = 0 + numpy_random_seed: int = 1234 + torch_random_seed: int = 1234 + fewshot_random_seed: int = 1234 + + +@dataclass +class ValidationArgs: + max_steps: Optional[int] = ( + None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) + ) + use_val_from_train_src: bool = True # Use the validation set from training sources + root_dir: str = "" + sources: List[str] = field(default_factory=list) # Other sources to eval on + + +@dataclass +class EvalArgs: + name: str = "evals" + dump_dir: Optional[str] = None + metric_log_dir: Optional[str] = None + ckpt_dir: str = "" + generator: PackedCausalTransformerGeneratorArgs = field( + default_factory=PackedCausalTransformerGeneratorArgs + ) + harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs) + validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs) + + wandb: Optional[Any] = None + + global_step: Optional[int] = None # for in-training evaluation + + +def all_dicts_same(dict_list): + if not dict_list: # Check if the list is empty + return True + + # Compare each dictionary to the first one + first_dict = dict_list[0] + return all(d == first_dict for d in dict_list) + + +class MockAccelerator: + def gather(self, tensor): + l = [torch.zeros_like(tensor) for _ in range(get_world_size())] + torch.distributed.all_gather(l, tensor) + return torch.stack(l) + + def wait_for_everyone(self): + torch.distributed.barrier() + + +# Light wrapper around generator for lm-eval harness +class EvalHarnessLM(LM): + def __init__(self, generator): + super().__init__() + self.generator = generator + self.accelerator = MockAccelerator() + self._rank = get_global_rank() + self._world_size = get_world_size() + self.device = generator.device + + def generate_until(self, requests: List[Instance]) -> 
List[str]: + prompts, gen_args = zip(*[req.args for req in requests]) + assert all_dicts_same(gen_args), "Doesn't support different gen args for now" + gen_args = gen_args[0] + temperature = gen_args.get("temperature", 0.0) + top_p = gen_args.get("top_p", None) + top_k = gen_args.get("top_k", None) + until = gen_args.get("until", []) + + self.generator.temperature = temperature + self.generator.top_p = top_p + self.generator.top_k = top_k + self.generator.until = until + generations, _, _ = self.generator.generate(prompts) + filtered_gen = [] + for g in generations: + for e in until: + g = g.replace(e, "") + filtered_gen.append(g) + return filtered_gen + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + prompts, continuations = zip(*[req.args for req in requests]) + inputs = [req.args[0] + req.args[1] for req in requests] + max_gen_len = self.generator.max_gen_len + # We temporarily lower max gen len + self.generator.max_gen_len = 1 + _, lls, greedy = self.generator.generate(inputs) + results = [] + for p, ll, gr in zip(prompts, lls, greedy): + p_len = len( + self.generator.tokenizer.encode(p, add_bos=False, add_eos=False) + ) + results.append((ll[p_len:].sum().item(), gr[p_len:].all().item())) + + self.generator.max_gen_len = max_gen_len + return results + + def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + prompts = [req.args[0] for req in requests] + max_gen_len = self.generator.max_gen_len + # We temporarily lower max gen len + self.generator.max_gen_len = 1 + _, lls, _ = self.generator.generate(prompts) + results = [] + for ll in lls: + results.append((ll.sum().item(),)) + self.generator.max_gen_len = max_gen_len + + return results + + +def eval_on_val(generator, val_args: ValidationArgs, train_cfg): + srcs = {} + for src in val_args.sources: + path = os.path.join(val_args.root_dir, src) + srcs[path] = 1.0 + for src in train_cfg.data.sources: + path = os.path.join(train_cfg.data.root_dir, src) + srcs[path] = 1.0 + + multi_state = init_choice_state( + "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl" + ) + path_to_iter = setup_sources(multi_state) + + max_gen_len = generator.max_gen_len + # We temporarily lower max gen len + generator.max_gen_len = 1 + + all_val_metrics = {} + for src in path_to_iter: + jsonl_iterator = path_to_iter[src] + texts = [] + logger.info(f"Running validation on {src}...") + for step, (content, state) in enumerate(jsonl_iterator): + if state["current_iter"] > 0 or ( + val_args.max_steps is not None and step >= val_args.max_steps + ): + break + content_key = "text" if ("text" in content) else "content" + texts.append(content[content_key]) + + _, loglikelihood, _ = generator.generate(texts) + + metrics = defaultdict(list) + for i, ll in enumerate(loglikelihood): + tmp = ll.sum().item() + metrics["nll"].append(tmp) + metrics["nll_per_token"].append(tmp / len(ll)) + metrics["nll_per_char"].append(tmp / len(texts[i])) + + metrics["avg_seqlen"].append(len(ll)) + + for m in metrics: + metrics[m] = sum(metrics[m]) / len(metrics[m]) + metrics.update(dist_mean_dict(metrics)) + logger.info(f"Validation on {src} done. 
Metrics: {metrics}") + + name = os.path.basename(src) + if name in all_val_metrics: + logger.warning( + f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1" + ) + name = f"{name}_1" + all_val_metrics[name] = metrics + + generator.max_gen_len = max_gen_len + + return all_val_metrics + + +def launch_eval(cfg: EvalArgs): + if not torch.distributed.is_initialized(): + setup_torch_distributed(DistributedArgs()) + if ( + Path(cfg.ckpt_dir).exists() + and (Path(cfg.ckpt_dir) / "params.json").exists() + and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None + ): + consolidate_path = Path(cfg.ckpt_dir) + else: + consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER + if not consolidate_path.exists() and get_global_rank() == 0: + consolidate_path = consolidate_checkpoints(cfg.ckpt_dir) + + Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True) + dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False) + + consolidate_path = str(consolidate_path) + torch.distributed.barrier() + logger.info("Loading model") + model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( + consolidate_path, + model_cls=LMTransformer, + model_args_cls=LMTransformerArgs, + ) + logger.info("Model loaded") + model.eval() + generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer) + + wrap = EvalHarnessLM(generator) + results = simple_evaluate(wrap, **asdict(cfg.harness)) + val_results = None + if cfg.validation: + val_results = eval_on_val(generator, cfg.validation, train_cfg) + if get_global_rank() == 0: + with open(Path(cfg.dump_dir) / "results.json", "w") as f: + f.write(json.dumps(results)) + logger.info(f"All evaluation results: {results['results']}") + if val_results is not None: + with open(Path(cfg.dump_dir) / "validation.json", "w") as f: + f.write(json.dumps(val_results)) + logger.info(f"All validation results: {val_results}") + if cfg.metric_log_dir and get_global_rank() == 0: + metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl" + + logger.info(f"Writing metric logs to {metric_log_path}") + timestamp = { + "created_at": datetime.utcnow().isoformat(), + } + if cfg.global_step is not None: + timestamp["global_step"] = cfg.global_step + print( + json.dumps(timestamp | results["results"]), + file=open(metric_log_path, mode="a"), + flush=True, + ) + + val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl" + if val_results is not None: + print( + json.dumps(timestamp | val_results), + file=open(val_log_path, mode="a"), + flush=True, + ) + + del generator + + +def main(): + """ + The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments + This accepts arguments as a dot list + So if the dataclass looks like + + @dataclass + class DummyArgs: + name: str + model: LMTransformerArgsgs + + @dataclass + class LMTransformerArgsgs: + dim: int + + Then you can pass model.dim=32 to change values in LMTransformerArgsgs + or just name=tictac for top level attributes. + + The behavior here is as follows: + 1. We instantiate EvalArgs with its default values + 2. We override those default values with the ones in the provided config file + 3. 
We override the result with the additional arguments provided through command line + + For example, if the config is the following + + model: + dim: 128 + n_layers: 4 + + and you call eval.py with eval.py model.dim=64 + + Then the final TrainArgs will have + + model: + dim: 64 + n_layers: 4 + + Plus all the default values in EvalArgs dataclass. + """ + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.structured(EvalArgs()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_object(cfg) + launch_eval(cfg) + + +if __name__ == "__main__": + main() diff --git a/apps/main/generate.py b/apps/main/generate.py new file mode 100644 index 0000000..a3a8627 --- /dev/null +++ b/apps/main/generate.py @@ -0,0 +1,463 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional + +import torch +from lingua.args import dataclass_from_dict +from lingua.tokenizers.abstract_tokenizer import Tokenizer +from lingua.tokenizers.build_tokenizer import build_tokenizer +from omegaconf import OmegaConf +from torch import nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import create_block_mask +from tqdm import tqdm + +from bytelatent.base_transformer import ( + Attention, + causal_mask, + generate_doc_mask_mod, + lengths_to_local_ids, + lengths_to_start_ids, +) +from bytelatent.checkpoint import CONSOLIDATE_NAME +from bytelatent.transformer import LMTransformer, LMTransformerArgs + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + + +def sample_top_k(probs, k): + topk_value, _ = torch.topk(probs, k) # batch_sz x topk + min_value_top_k = topk_value[:, [-1]] + probs[probs < min_value_top_k] = 0.0 + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs, num_samples=1) + return next_token + + +def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None): + shape = logits.shape + logits = logits.flatten(end_dim=-2) + if temperature > 0.0: + probs = torch.softmax(logits / temperature, dim=-1) + + if top_p is not None: + next_token = sample_top_p(probs, top_p) + elif top_k is not None: + next_token = sample_top_k(probs, top_k) + else: + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(logits, dim=-1) + return next_token.view(shape[:-1]) + + +def pack_prompts(prompts: List[int]): + res = [] + lengths = [] + for i, p in enumerate(prompts): + p = torch.tensor(p, dtype=torch.long) + l = p.size(0) + res.append(p) + lengths.append(l) + lengths = torch.tensor(lengths, dtype=torch.long) + res = torch.cat(res) + return res, lengths + + +def batch_prompts(prompts, max_elements, lengths=None): + batches = [] + current_batch = [] + current_count = 0 + + for i in range(len(prompts)): + prt = prompts[i] + prompt_size = len(prt) if lengths is None else lengths[i] + if current_count + prompt_size <= max_elements: + current_batch.append(prt) + current_count += prompt_size + else: + if current_batch: # Add 
the current batch to batches + batches.append(current_batch) + # Start a new batch with the current prompt + current_batch = [prt] + current_count = prompt_size + + # Add the last batch if it contains any prompts + if current_batch: + batches.append(current_batch) + + return batches + + +class KVCache(nn.Module): + def __init__(self, bsz, seqlen, n_heads, head_dim, dtype, device): + super().__init__() + shape = (bsz, seqlen, n_heads, head_dim) + self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype, device=device)) + self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype, device=device)) + self.offset = 0 + + def reset(self): + self.k_cache.zero_() + self.v_cache.zero_() + self.offset = 0 + + def update(self, k_val, v_val, tok_idx): + # input_pos: [B], k_val: [B, S, H, D] + self.k_cache.index_copy_(1, self.offset + tok_idx, k_val) + self.v_cache.index_copy_(1, self.offset + tok_idx, v_val) + return self.k_cache, self.v_cache + + +@dataclass +class PackedCausalTransformerGeneratorArgs: + temperature: float = 0.0 + top_p: Optional[float] = None + top_k: Optional[float] = None + max_gen_len: int = 512 # Maximum number of tokens to generate + max_tokens: int = 1024 # Maximum number of tokens that can go through the model + max_prompt_len: Optional[int] = None + until: List[str] = field(default_factory=list) + compile_prefilling: bool = False + reduce_generation_overhead: bool = False + show_progress: bool = False + dtype: Optional[str] = "bf16" + device: Optional[str] = "cuda" + + +class PackedCausalTransformerGenerator: + def __init__( + self, + cfg: PackedCausalTransformerGeneratorArgs, + model: nn.Module, + tokenizer: Tokenizer, + ): + """ + This class wraps a causal transformer model with its corresponding tokenizer + and provides an efficient way to pack prompts together and do generation on + the packed sequence. + + For example, if we had the prompts "Hello, I am a " and "Initiating calibration " + Then this class will concatenate those sequence (pack them together) + "Hello, I am a Initiating calibration" + And make the necessary attention masks such that a sequence only attends to itself + during prefilling and generation. + + This class creates a fixed size cache of size max_tokens or sum of prompt sizes + + the max number of generated tokens per sequence. 
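+
+        A minimal usage sketch (here `model` and `tokenizer` are assumed to come
+        from load_consolidated_model_and_tokenizer below, and the config values
+        are illustrative only):
+
+            cfg = PackedCausalTransformerGeneratorArgs(temperature=0.0, max_gen_len=64)
+            generator = PackedCausalTransformerGenerator(cfg, model, tokenizer)
+            texts, loglikelihoods, greedy = generator.generate(["Hello, I am a "])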
+ """ + self.model = model + self.tokenizer = tokenizer + self.temperature = cfg.temperature + self.top_p = cfg.top_p + self.top_k = cfg.top_k + + self.max_gen_len = cfg.max_gen_len + self.max_tokens = cfg.max_tokens + self.max_prompt_len = cfg.max_prompt_len + self.until = cfg.until + self.max_until_size = max([len(e) for e in self.until]) if self.until else 1 + self.device = cfg.device + + # Compile if necessary + self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling) + self.generate_next_token = torch.compile( + self.generate_next_token, + mode="reduce-overhead", + disable=not cfg.reduce_generation_overhead, + ) + + self.show_progress = cfg.show_progress + self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype] + + self.prefill_doc_id, self.prefill_tok_id = None, None + self.padded_doc_id, self.padded_tok_id = None, None + self.current_doc_id, self.current_tok_id = None, None + self.padded_doc_start = None + self.prefill_mask = None + + def clear_cache(self, offset): + for module in self.model.modules(): + if isinstance(module, Attention): + if not hasattr(module, "kv_cache"): + module.kv_cache = KVCache( + 1, + self.max_tokens, + module.n_kv_heads, + module.head_dim, + self.dtype, + self.device, + ) + module.kv_cache.offset = offset + + @torch.compiler.disable + def setup_prefilling(self, lengths: torch.Tensor): + # The KV cache is a fixed size tensor of size max_tokens that we need + # to update in order to do correct autoregressive generation. + + # Here we will generate token by token but on multiple sequences + # at once. To do so, we need to have an attention mask that makes + # each sequence independent. + + # Each sequence will write to its allocated space in the KV Cache. + # We allocate len(seq) + max_gen_len to each sequence in the cache. 
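+        # (Illustrative numbers: with lengths = [2, 3] and max_gen_len = 2,
+        # padded_lengths below starts as [4, 5] before its last entry is grown
+        # to fill the cache up to max_tokens.)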
+ + # We will generate max_gen_len for each document + padded_lengths = lengths + self.max_gen_len + max_tokens = self.max_tokens or padded_lengths.sum().item() + # The last document might have more padding to fill up to max_tokens + padded_lengths[-1] += max_tokens - padded_lengths.sum() + + # This is the start index in the cache for each document + self.padded_doc_start = lengths_to_start_ids(padded_lengths) + # For example with ab--123--cdef-- + # this would be 0, 4, 9 if max_gen_len is 2 + + # We repeat interleave to align with tokens for prefilling + # Ex: ab--123--cdef-- + # 000044444999999 + prefill_offset = torch.repeat_interleave(self.padded_doc_start, lengths) + # This offset will make sure the tokens are written to the + # correct positions in the cache during prefilling + + # We either init the cache or clear it by resetting the offset to prefill_offset + self.clear_cache(prefill_offset) + + # The prefilling mask looks like the following for + # the two packed sequences ab and 123 : ab123 + # Where spaces are empty cache positions + # keys + # ab---123--- + # queries a 10000000000 + # b 11000000000 + # 1 00000100000 + # 2 00000110000 + # 3 00000111000 + # We make sure to skip the empty cache positions + # and only attend to positions within the same sequence + doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths, padded_lengths) + self.prefill_mask = create_block_mask( + doc_mask_mod, 1, None, lengths.sum(), max_tokens + ) + + # This creates the prefilling token ids which look like + # the following for the packed sequence abcdefg1234 + # abcdefg1234 + # 01234560123 + # The token id gives us the position within each sequence + # This is used to compute ROPE and to update the cache + # At each forward pass the current tokens are written to + # offset + tok_id + self.prefill_doc_id, self.prefill_tok_id = lengths_to_local_ids(lengths) + + # This creates the padded token and document ids + # which look like the following for the packed sequence ab123 + # ab---123--- ab---123--- + # padded_doc_id 00000111111 padded_tok_id 01234012345 + # This will later be useful for the attention mask at generation + self.padded_doc_id, self.padded_tok_id = lengths_to_local_ids(padded_lengths) + + @torch.compiler.disable + def setup_generation(self, lengths): + # KV Cache offset is set to the start of the padded documents + for module in self.model.modules(): + if isinstance(module, Attention): + module.kv_cache.offset = self.padded_doc_start + # The token ids during generations correspond to the lengths of each doc + # current_tok_id will be incremented during generation + self.current_tok_id = lengths.clone() + # Since we're generating one token per document + # the document id is just an arange + self.current_doc_id = torch.arange(lengths.size(0), device=lengths.device) + + # From here on some methods for generation + def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor): + # Prefilling is done by taking multiple packed sequences and + # doing block diagonal attention on them so they remain independent + self.setup_prefilling(lengths=lengths) + prefill_out = self.model.forward( + tokens, + tok_idx=self.prefill_tok_id, + mask=self.prefill_mask, + attn_impl="flex_attention", + ) + self.setup_generation(lengths=lengths) + return prefill_out + + def generate_next_token(self, current_token): + # Since we're doing generation with multiple sequences at once + # we need to ignore tokens and cache entries from other sequences + # or in the future. 
+ # Example mask : + # keys + # abc--1234-- + # queries c 11100000000 + # 4 00000111100 + + # mask shape : (n_seqs, cache_size) + doc_mask = self.current_doc_id.unsqueeze(1) == self.padded_doc_id.unsqueeze(0) + caus_mask = self.current_tok_id.unsqueeze(1) >= self.padded_tok_id.unsqueeze(0) + mask = doc_mask & caus_mask + out = self.model.forward( + current_token, + tok_idx=self.current_tok_id, # n_seqs + mask=mask, + attn_impl="sdpa", + ) + self.current_tok_id += 1 + return out + + @torch.inference_mode() + def generate(self, prompts): + # Tokenize + prompts = [ + self.tokenizer.encode(p, add_bos=True, add_eos=False) for p in prompts + ] + # Truncate + max_seqlen = ( + self.max_tokens + if not hasattr(self.model, "max_seqlen") + else self.model.max_seqlen + ) + max_prompt_len = self.max_prompt_len or min( + max_seqlen - self.max_gen_len, self.max_tokens - self.max_gen_len + ) + prompts = [p[-max_prompt_len:] for p in prompts] + # Account for the generation in lengths + padded_lengths = [len(p) + self.max_gen_len for p in prompts] + generation = [] + loglikelihood = [] + greedy = [] + it = batch_prompts(prompts, self.max_tokens, lengths=padded_lengths) + if self.show_progress: + it = tqdm(it) + for batch in it: + n_seqs = len(batch) + generated_tokens = [[] for _ in range(n_seqs)] + is_done = [False for _ in range(n_seqs)] + packed_batch, lengths = pack_prompts(batch) + packed_batch, lengths = packed_batch.cuda(), lengths.cuda() + n_seqs = lengths.size(0) + + # Prefilling cache + prompt_logits = self.prefill(packed_batch.unsqueeze(0), lengths) + # Selecting last token in each prompt + all_tokens = sample_tokens( + prompt_logits, self.temperature, self.top_p, self.top_k + ) + start_token = all_tokens[:, lengths.cumsum(0) - 1] + + for seq_id, tok in enumerate(start_token.squeeze(0).tolist()): + generated_tokens[seq_id].append(tok) + + current_token = start_token + for i in range(1, self.max_gen_len): + + next_logits = self.generate_next_token(current_token) + next_token = sample_tokens( + next_logits.clone(), self.temperature, self.top_p, self.top_k + ) + + for seq_id, tok in enumerate(next_token.squeeze(0).tolist()): + if not is_done[seq_id]: + generated_tokens[seq_id].append(tok) + current_end_str = self.tokenizer.decode( + generated_tokens[seq_id][-self.max_until_size :] + ) + contains_end_string = any( + [e in current_end_str for e in self.until] + ) + is_done[seq_id] = ( + contains_end_string or tok == self.tokenizer.eos_id + ) + if all(is_done): + break + + current_token = next_token + + generation.extend([self.tokenizer.decode(g) for g in generated_tokens]) + + for p, logit in zip( + batch, prompt_logits.squeeze(0).split(lengths.tolist()) + ): + x = logit[:-1] + y = torch.tensor(p[1:], device=x.device) + loglikelihood.append(-F.cross_entropy(x, y, reduction="none").cpu()) + greedy.append((x.argmax(dim=-1) == y).cpu()) + + return generation, loglikelihood, greedy + + +def load_consolidated_model_and_tokenizer( + consolidated_path, + model_cls=LMTransformer, + model_args_cls=LMTransformerArgs, +): + ckpt_path = Path(consolidated_path) + config = ckpt_path / "params.json" + config = OmegaConf.load(config) + + param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ + config.distributed.model_dtype + ] + model_args = dataclass_from_dict(model_args_cls, config.model, strict=False) + tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path) + model = model_cls(model_args) + st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True) 
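+    # The consolidated checkpoint is expected to hold the weights under the "model"
+    # key; weights_only=True keeps torch.load restricted to plain tensor data.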
+ model.load_state_dict(st_dict["model"]) + model = model.cuda().eval() + for param in model.parameters(): + param.data = param.data.to(dtype=param_dtype) + return model, tokenizer, config + + +def main(): + # Load CLI arguments (overrides) and combine with a YAML config + cfg = OmegaConf.from_cli() + gen_cfg = dataclass_from_dict( + PackedCausalTransformerGeneratorArgs, cfg, strict=False + ) + print(cfg) + + model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt) + + generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer) + + # Allow multiple prompts + prompts = [] + while True: + prompt = input("Enter a prompt (or press enter to finish): ") + if not prompt: + break + prompts.append(prompt) + + # Start generation + start_time = time.time() + generation, loglikelihood, greedy = generator.generate(prompts) + end_time = time.time() + + # Calculate tokens per second + total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation) + tokens_per_second = total_tokens / (end_time - start_time) + + # Display the results + for i, gen in enumerate(generation): + print(f"\nPrompt {i+1}: {prompts[i]}") + print(f"Generated Text: {gen}") + + print(f"\nTokens per second: {tokens_per_second:.2f}") + + +if __name__ == "__main__": + main() diff --git a/apps/main/lingua_train.py b/apps/main/lingua_train.py new file mode 100644 index 0000000..bdb47da --- /dev/null +++ b/apps/main/lingua_train.py @@ -0,0 +1,654 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import gc +import logging +import os +import sys +from contextlib import ExitStack +from copy import deepcopy +from dataclasses import asdict, dataclass, field +from pathlib import Path +from timeit import default_timer as timer +from typing import Any, Dict, Optional + +import torch +import torch.distributed +import wandb +import xformers.profiler +from lingua.args import dataclass_from_dict, dump_config, flatten_dict +from lingua.data import ( + DataArgs, + PackTokensState, + build_dataloader_from_args, + init_dataloader_state_from_args, +) +from lingua.tokenizers.build_tokenizer import TokenizerArgs +from omegaconf import OmegaConf +from pydantic import BaseModel +from torch.distributed._tensor import DTensor +from torch.distributed.checkpoint.stateful import Stateful +from torch.optim import lr_scheduler + +from bytelatent.checkpoint import ( + CheckpointArgs, + CheckpointManager, + load_from_checkpoint, +) +from bytelatent.distributed import ( + DistributedArgs, + EnvironmentArgs, + check_model_value_range, + clean_env, + dist_mean_dict, + get_device_mesh, + get_is_master, + get_world_size, + init_signal_handler, + parallelize_model, + requeue_slurm_job, + setup_env, + setup_torch_distributed, +) +from bytelatent.logger import init_logger +from bytelatent.metrics import ( + GPUMemoryMonitor, + LoggingArgs, + MetricLogger, + get_num_params, +) +from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.probe import AutoProbeD +from bytelatent.profiling import ProfilerArgs, maybe_run_profiler +from bytelatent.stool import StoolArgs, launch_job +from bytelatent.transformer import ( + LMTransformer, + LMTransformerArgs, + build_fsdp_grouping_plan, + get_no_recompute_ops, + get_num_flop_per_token, + tp_parallelize, +) + +logger = logging.getLogger() + + +class TrainArgs(BaseModel): + name: str = "lingua" + dump_dir: str = "" + + seed: int = 42 + + # Number of gradient 
accumulation steps + # Total batch size is batch_size*grad_acc_steps + grad_acc_steps: int = 1 + + gc_collect_freq: int = 1000 + probe_freq: int | None = None + + # Nb optimizer steps to take + steps: int = 1000 + + data: DataArgs + optim: OptimArgs + model: LMTransformerArgs + distributed: DistributedArgs + env: EnvironmentArgs + + checkpoint: CheckpointArgs + profiling: ProfilerArgs + logging: LoggingArgs + + # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus + async_eval_gpus: int | None = None + eval: Any | None = None + + +@dataclass +class TrainState(Stateful): + step: int # Nb of steps taken by the optimizer + acc_step: int # Nb of accumulation steps done since last optimizer step + scheduler: lr_scheduler.LambdaLR + data_loader_state: PackTokensState + + def state_dict(self) -> Dict[str, Any]: + return { + "step": self.step, + "acc_step": self.acc_step, + "data_loader_state": self.data_loader_state, + "scheduler": self.scheduler.state_dict(), + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.acc_step = state_dict["acc_step"] + self.data_loader_state = PackTokensState(**state_dict["data_loader_state"]) + self.scheduler.load_state_dict(state_dict["scheduler"]) + + +def validate_train_args(args: TrainArgs, output_size: int): + if args.model.vocab_size < 0: + logger.info(f"Setting model output size to {args.model.vocab_size}") + args.model.vocab_size = output_size + assert ( + args.model.vocab_size == output_size + ), "Vocab size should be the same as output size" + + assert args.dump_dir, "Dump dir not set" + + if args.checkpoint.path is None: + logger.info(f"Setting checkpoint path to {args.checkpoint.path}") + args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints") + + for source in args.data.sources: + data_path = os.path.join(args.data.root_dir, source) + assert os.path.exists(data_path), f"{data_path} doesn't exist" + + if ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + != get_world_size() + ): + assert get_world_size() % args.distributed.dp_shard == 0 + args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard + + assert args.distributed.dp_replicate % args.distributed.tp_size == 0 + args.distributed.dp_replicate = ( + args.distributed.dp_replicate // args.distributed.tp_size + ) + + logger.warning( + f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}" + ) + assert ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + == get_world_size() + ) + + if args.distributed.fsdp_type == "no_shard": + assert ( + args.distributed.dp_shard == 1 + and args.distributed.dp_replicate == get_world_size() + ) + + args.model.max_seqlen = args.data.seq_len + + if args.distributed.tp_size == 1: + logger.warning( + "Tensor parallelism has not been tested for a while, use at your own risk" + ) + + assert ( + args.probe_freq != args.profiling.mem_steps + ), "Don't profile during probe step" + assert ( + args.probe_freq != args.profiling.profile_steps + ), "Don't profile during probe step" + if args.logging.wandb is not None: + args.logging.wandb.name = args.name + + if args.probe_freq is not None: + assert ( + args.distributed.tp_size == 1 + ), "Probing not supported with tensor parallelism" + assert ( + args.distributed.selective_activation_checkpointing is False + ), "Probing not supported with selective activation checkpointing" + + +preemption_flag = 
dict(flag=False) + + +def set_preemption_flag(signum, frame): + logger.warning("Signal handler called with signal " + str(signum)) + logger.warning("Preemption ! checkpointing asap and exiting.") + preemption_flag["flag"] = True + + +def every_n_steps(train_state, freq, acc_step=None, acc_freq=None): + test = train_state.step % freq == 0 + if acc_step is not None: + test = test and (train_state.acc_step == acc_step) + elif acc_freq is not None: + test = test and ((train_state.acc_step % acc_freq) == 0) + return test + + +def train(args: TrainArgs): + with ExitStack() as context_stack: + tokenizer_args = TokenizerArgs( + name=args.data.name, + init_kwargs=args.data.tokenizer.init_kwargs, + ) + tokenizer = tokenizer_args.build() + validate_train_args( + args, + tokenizer.n_words, + ) + if get_is_master(): + os.makedirs(args.dump_dir, exist_ok=True) + dump_config(args, Path(args.dump_dir) / "config.yaml") + init_logger(Path(args.dump_dir) / "train.log") + init_signal_handler(set_preemption_flag) # For handling preemption signals. + setup_env(args.env) + setup_torch_distributed(args.distributed) + world_mesh = get_device_mesh(args.distributed) + logger.info(f"Starting job: {args.name}") + + # build dataloader + # need dp world size and rank + dp_mesh = world_mesh["dp_replicate"] + dp_degree = dp_mesh.size() + dp_rank = dp_mesh.get_local_rank() + if args.distributed.dp_shard > 1: + dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank() + dp_degree *= world_mesh["dp_shard"].size() + + logger.info(f"Running on dp rank : {dp_rank}") + logger.info(f"Running on dp size : {dp_degree}") + + torch.manual_seed(args.seed) + logger.info("Building model") + + # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory + with torch.device("meta"): + model = LMTransformer(args.model) + logger.info("Model is built !") + + model_param_count = get_num_params(model) + + model = parallelize_model( + model, + world_mesh, + args.model, + args.distributed, + fsdp_grouping_plan=build_fsdp_grouping_plan(args.model), + tp_parallelize=tp_parallelize, + no_recompute_ops=get_no_recompute_ops(), + ) + + # Once we shard the model on different gpus we can actually initialize the model + # First we create empty tensors of the correct shapes + model = model.to_empty(device="cuda") + # Then we init the model. 
Please make sure this function initializes *ALL* parameters + # and buffers, otherwise you will have random values in the unitialized tensors + # which will silently fail (give nan gradients for example) + + if args.checkpoint.init_ckpt_path: + logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + load_from_checkpoint( + args.checkpoint.init_ckpt_path, model, model_key="model" + ) # Put model_key="" if its directly the model checkpoint + model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded + else: + with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + torch.manual_seed(args.model.seed) + model.init_weights() + check_model_value_range(model, range=10.0, std=1.0) + + # log model size + + logger.info(f"Model size: {model_param_count:,} total parameters") + + gpu_memory_monitor = GPUMemoryMonitor("cuda") + logger.info( + f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) " + f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory" + ) + logger.info(f"GPU memory usage: {gpu_memory_monitor}") + + # build optimizer after apply parallelisms to the model + optimizer, scheduler = build_optimizer(model, args.optim, args.steps) + data_loader_state = init_dataloader_state_from_args( + args.data, dp_rank, dp_degree + ) + + train_state = TrainState( + step=0, + acc_step=0, + data_loader_state=data_loader_state, + scheduler=scheduler, + ) + + checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint) + checkpoint.load(model, optimizer, train_state, world_mesh) + # Either load from latest checkpoint or start from scratch + if args.probe_freq is not None: + if get_is_master(): + os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True) + torch.distributed.barrier() + probe = AutoProbeD( + model, + ( + Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl" + if (dp_rank % 128 == 0) + else None + ), + ) + probe_mod = model._orig_mod if args.distributed.compile else model + + gc.disable() + + # train loop + model.train() + metric_logger = context_stack.enter_context( + MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args) + ) + data_loader = context_stack.enter_context( + build_dataloader_from_args( + args.data, + state=train_state.data_loader_state, + ) + ) + torch_profiler = context_stack.enter_context( + maybe_run_profiler(args.dump_dir, model, args.profiling) + ) + + nwords_since_last_log = 0 + time_last_log = timer() + gc.collect() + while train_state.step < args.steps: + # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 + train_state.acc_step += 1 + train_state.acc_step = train_state.acc_step % args.grad_acc_steps + + # get batch + curr_lr = float(optimizer.param_groups[0]["lr"]) + data_load_start = timer() + batch, train_state.data_loader_state = next(data_loader) + batch = torch.tensor( + batch, + dtype=torch.long, + ) + + if every_n_steps(train_state, args.gc_collect_freq, acc_step=0): + logger.info("garbage collection") + # we do garbage collection manually otherwise different processes + # run the GC at different times so they slow down the whole pipeline + gc.collect() + + input_ids = batch[:, :, 0].cuda() + labels = batch[:, :, 1].cuda() + data_load_time = round(timer() - data_load_start, 4) + nwords_since_last_log += input_ids.numel() + + bsz, seqlen = labels.shape + + # forward + start_timer = torch.cuda.Event(enable_timing=True) + end_timer = torch.cuda.Event(enable_timing=True) + start_timer.record() + + # 
This is an automatic probe that will compute statistics + # of all linears' inputs, weights and outputs + # along with attention logits and entropy + # both in forward and backward pass + if (args.probe_freq is not None) and every_n_steps( + train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps + ): + # Here we do a fake forward and backward pass on a smaller + # batch size to avoid OOM + # This assumes the model has no stateful layers (batch norm..) + assert ( + next(probe_mod.parameters()).grad is None + ), "Can't probe model if grads are not reset" + + with probe: + probe.metadata = { + "it": train_state.step, + "global_step": train_state.step, + "loop": "lingua", + } + # Non compiled model uses roughly 2x memory in our exps + # So we divide bsz by 2 or seqlen by 2 + probe_bsz = max(1, bsz // 2) + probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2) + probe_loss = probe_mod( + input_ids[:probe_bsz, :probe_seq], + labels[:probe_bsz, :probe_seq], + ) + probe_loss.backward() + # We zero grads to cancel this fake step + optimizer.zero_grad() + + assert ( + next(probe_mod.parameters()).grad is None + ), "Probe model shouldn't have grads at this point" + loss = model(input_ids, labels) + + # We scale loss with grad_acc_steps so the gradient is the same + # regardless of grad_acc_steps + loss = loss / args.grad_acc_steps + # backward on scaled loss to create scaled gradients + loss.backward() + # For logging we undo that scaling + loss = loss.detach() * args.grad_acc_steps + + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + + grad_norm = ( + grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm + ).item() + + # optimizer step + if train_state.acc_step == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + train_state.step += 1 + + # updates the scale for next iteration + # training iteration complete + end_timer.record() + + torch.cuda.synchronize() + + curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4) + + # if profiler is active + if torch_profiler: + xformers.profiler.step() + + # log metrics + if every_n_steps( + train_state, + args.logging.freq, + acc_step=None if args.logging.acc_freq else 0, + acc_freq=args.logging.acc_freq, + ): + time_delta = timer() - time_last_log + wps = nwords_since_last_log / (time_delta * args.distributed.tp_size) + + gpu_mem_stats = gpu_memory_monitor.get_peak_stats() + + total_acc_steps = ( + args.grad_acc_steps * train_state.step + train_state.acc_step + ) + tokens_per_gpu = ( + total_acc_steps * args.data.batch_size * args.data.seq_len + ) + total_tokens = dp_degree * tokens_per_gpu + # This is an estimate and the correct values may change + # if you change the architecture + # Use xformer's analyze profile trace to get actual measurement + FLOPS = ( + get_num_flop_per_token( + model_param_count - args.model.vocab_size * args.model.dim, + args.model.n_layers, + args.model.dim, + args.data.seq_len, + ) + * wps + ) + metrics = flatten_dict( + { + "global_step": train_state.step, + "acc_step": train_state.acc_step, + "speed": { + "wps": wps, + "FLOPS": FLOPS, + "curr_iter_time": curr_iter_time, + "data_load_time": data_load_time, + }, + "optim": { + "grad_norm": grad_norm, + "lr": curr_lr, + "total_tokens": total_tokens, + }, + "memory": gpu_mem_stats._asdict(), + }, + sep="/", + ) + + to_sync = {} + to_sync["loss/out"] = loss.item() + metrics.update(dist_mean_dict(to_sync)) + + if get_is_master(): + metric_logger.log(metrics) + + 
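+                # Reset the per-window counters below so the next logging interval
+                # measures throughput and peak memory from a clean slate.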
gpu_memory_monitor.reset_peak_stats() + nwords_since_last_log = 0 + time_last_log = timer() + logger.info( + f"step: {train_state.step}" + f" acc: {train_state.acc_step}" + f" loss: {round(loss.item(),4):>7}" + f" grad: {grad_norm:.2e}" + f" flops: {FLOPS:.2e}" + f" wps: {wps:.2e}" + f" iter: {curr_iter_time:>7}" + f" data: {data_load_time:>5}" + f" lr: {curr_lr:.2e}" + f" mem: {gpu_mem_stats.max_active_pct:.0f}%" + f" pow: {gpu_mem_stats.power_draw/1000} W" + ) + + saved = False + if every_n_steps( + train_state, args.checkpoint.dump.every, acc_step=0 + ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + saved = checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + + if args.eval is not None and every_n_steps( + train_state, args.checkpoint.eval.every, acc_step=0 + ): + from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval + + eval_args = dataclass_from_dict(EvalArgs, args.eval) + + eval_args.global_step = train_state.step + eval_args.ckpt_dir = str(checkpoint.existing_saves[-1]) + eval_args.dump_dir = str( + os.path.join( + args.dump_dir, + "evals", + EVAL_FOLDER_NAME.format(train_state.step), + ) + ) + eval_args.metric_log_dir = args.dump_dir + if args.async_eval_gpus is None: + launch_eval(eval_args) + elif get_is_master(): + if wandb.run is not None and args.logging.wandb is not None: + eval_args.wandb = deepcopy(args.logging.wandb) + assert args.async_eval_gpus > 0 + logger.info(f"Launching evals on {args.async_eval_gpus} gpus") + with clean_env(): + launch_job( + StoolArgs( + asdict(eval_args), + script="apps.main.eval", + copy_code=False, + nodes=args.async_eval_gpus // 8, + qos="lowest", + ) + ) + + if preemption_flag["flag"]: + if not saved: + checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + requeue_slurm_job() + sys.exit(0) + + if not saved: + checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + gc.collect() + + +def main(): + """ + The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments + This accepts arguments as a dot list + So if the dataclass looks like + + @dataclass + class DummyArgs: + name: str + model: LMTransformerArgsgs + + @dataclass + class LMTransformerArgsgs: + dim: int + + Then you can pass model.dim=32 to change values in LMTransformerArgsgs + or just name=tictac for top level attributes. + + The behavior here is as follows: + 1. We instantiate TrainArgs with its default values + 2. We override those default values with the ones in the provided config file + 3. We override the result with the additional arguments provided through command line + + For example, if the config is the following + + model: + dim: 128 + n_layers: 4 + + and you call train.py with train.py model.dim=64 + + Then the final TrainArgs will have + + model: + dim: 64 + n_layers: 4 + + Plus all the default values in TrainArgs dataclass. 
+ """ + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.structured(TrainArgs()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_object(cfg) + + train(cfg) + + +if __name__ == "__main__": + main() diff --git a/blt-figure.jpg b/blt-figure.jpg new file mode 100644 index 0000000..e27edc6 Binary files /dev/null and b/blt-figure.jpg differ diff --git a/blt-figure.pdf b/blt-figure.pdf new file mode 100644 index 0000000..045f15d Binary files /dev/null and b/blt-figure.pdf differ diff --git a/bytelatent/.DS_Store b/bytelatent/.DS_Store new file mode 100644 index 0000000..5a50b07 Binary files /dev/null and b/bytelatent/.DS_Store differ diff --git a/bytelatent/__init__.py b/bytelatent/__init__.py new file mode 100644 index 0000000..5bc057f --- /dev/null +++ b/bytelatent/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +class ByteLatentError(Exception): + pass diff --git a/bytelatent/args.py b/bytelatent/args.py new file mode 100644 index 0000000..4fba100 --- /dev/null +++ b/bytelatent/args.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +import os +from typing import Any + +import numpy as np +import yaml +from pydantic import BaseModel, ConfigDict + +from bytelatent.checkpoint import CheckpointArgs +from bytelatent.data.data_types import Batch +from bytelatent.data.iterators.abstract_iterator import StatefulIterator +from bytelatent.data.iterators.arrow_iterator import ( + ArrowFileIterator, + find_and_sanitize_chunks, +) +from bytelatent.data.iterators.looping_iterator import LoopingIterator +from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator +from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator +from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator +from bytelatent.data.iterators.sampling_iterator import SamplingIterator +from bytelatent.data.iterators.sequence_iterator import ( + SequenceIterator, + SequencePackingArgs, +) +from bytelatent.data.patcher import PatcherArgs +from bytelatent.distributed import DistributedArgs, EnvironmentArgs +from bytelatent.metrics import LoggingArgs +from bytelatent.model.blt import ByteLatentTransformerArgs +from bytelatent.optim import OptimArgs +from bytelatent.profiling import ProfilerArgs +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + +logger = logging.getLogger() + + +def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: + return np.random.default_rng((seed, rank, world_size)).bit_generator.state + + +def distribute_data_to_rank( + *, + dataset_path: str, + preprocess_dir: str, + entropy_model_name: str | None, + arrow_batch_size: int, + rank: int, + world_size: int, +) -> ArrowFileIterator: + dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + n_workers_per_chunk = world_size // len(dataset_chunks) + rank_to_arrow_iterator_params = [] + for chunk_path in dataset_chunks: + for worker_id in range(n_workers_per_chunk): + rank_to_arrow_iterator_params.append( + ArrowFileIterator( + file_path=chunk_path, + worker_id=worker_id, + num_workers=n_workers_per_chunk, + preprocess_dir=preprocess_dir, + dataset_files=None, + entropy_model_name=entropy_model_name, + arrow_batch_size=arrow_batch_size, + ) + ) + return rank_to_arrow_iterator_params[rank] + + 
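+# Usage sketch (the seed/rank/world_size values below are illustrative):
+# get_rng_state() returns the bit_generator state of a numpy Generator seeded
+# with (seed, rank, world_size); restoring it gives a per-rank, resumable RNG.
+#
+#   rng = np.random.default_rng()
+#   rng.bit_generator.state = get_rng_state(seed=42, rank=0, world_size=8)
+#   rng.random()  # deterministic for this (seed, rank, world_size) triple
+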
+class DataloaderArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + root_dir: str | None = None + sources: dict[str, float] = {} + batch_size: int = 2 + seq_len: int = 2048 + seed: int = 42 + add_bos: bool = True + add_eos: bool = True + load_async: bool = True + prefetch_size: int = 64 + preprocess_dir: str | None = None + dataset_files: list[str] | None = None + entropy_model_name: str | None = "transformer_100m" + arrow_batch_size: int = 100 + buffer_size: int = 64 + + pad_to_max_length: bool = True + max_encoder_seq_length: int = 12288 + enable_byte_ngrams: bool = False + + tokenizer_args: TokenizerArgs = TokenizerArgs() + patcher_args: PatcherArgs = PatcherArgs() + + def _create_sequence_iterators( + self, rank: int, world_size: int + ) -> dict[str, SequenceIterator]: + sequence_packing_args = SequencePackingArgs( + output_seq_len=self.seq_len, + buffer_size=self.buffer_size, + ) + source_to_sequence_iterator: dict[str, SequenceIterator] = {} + for dataset_path in self.sources: + shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size) + arrow_iterator = distribute_data_to_rank( + dataset_path=os.path.join(self.root_dir, dataset_path), + preprocess_dir=self.preprocess_dir, + entropy_model_name=self.entropy_model_name, + arrow_batch_size=self.arrow_batch_size, + rank=rank, + world_size=world_size, + ) + looping_iterator = LoopingIterator(arrow_iterator) + preprocess_iterator = PreprocessIterator( + looping_iterator, + patcher_args=self.patcher_args, + tokenizer_args=self.tokenizer_args, + ) + sequence_iterator = SequenceIterator( + preprocess_iterator, + sequence_packing_args=sequence_packing_args, + rng_state=shuffle_rng_state, + ) + + source_to_sequence_iterator[dataset_path] = sequence_iterator + return source_to_sequence_iterator + + def build_from_rank( + self, rank: int, world_size: int + ) -> StatefulIterator[Batch, Any]: + source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size) + weight_rng_state = get_rng_state(self.seed + 1, rank, world_size) + sampling_iterator = SamplingIterator( + rng_state=weight_rng_state, + source_to_weight=self.sources, + source_to_iterator=source_to_sequence_iterators, + ) + tokenizer = self.tokenizer_args.build() + packing_args = PackingArgs( + batch_size=self.batch_size, + seq_len=self.seq_len, + pad_id=tokenizer.boe_id, + max_length=self.max_encoder_seq_length, + pad_to_max_length=self.pad_to_max_length, + enable_byte_ngrams=self.enable_byte_ngrams, + ) + packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) + mp_iterator = MultiprocessIterator( + packing_iterator, n_batches_to_prefetch=self.prefetch_size + ) + + return mp_iterator + + +class TrainArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + name: str = "lingua" + dump_dir: str = "" + + seed: int = 42 + + # Number of gradient accumulation steps + # Total batch size is batch_size*grad_acc_steps + grad_acc_steps: int = 1 + + gc_collect_freq: int = 1000 + probe_freq: int | None = None + + # Nb optimizer steps to take + steps: int = 1000 + + data: DataloaderArgs = DataloaderArgs() + optim: OptimArgs = OptimArgs() + model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + distributed: DistributedArgs = DistributedArgs() + env: EnvironmentArgs = EnvironmentArgs() + + checkpoint: CheckpointArgs = CheckpointArgs() + profiling: ProfilerArgs = ProfilerArgs() + logging: LoggingArgs = LoggingArgs() + + # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus + 
async_eval_gpus: int | None = None + eval: Any | None = None + eval_on_gpus: int | None = None + + def dump_to_yaml_file( + self, path: str, log_config: bool = True, sort_keys: bool = True + ): + model_dict = self.model_dump(mode="json") + yaml_str = yaml.dump( + model_dict, + allow_unicode=True, + sort_keys=sort_keys, + default_flow_style=False, + ) + with open(path, "w") as f: + if log_config: + logger.info("Using the following config for this run:") + logger.info(yaml_str) + f.write(yaml_str) diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py new file mode 100644 index 0000000..f494a15 --- /dev/null +++ b/bytelatent/base_transformer.py @@ -0,0 +1,585 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from enum import Enum +from typing import Optional, Tuple, Union + +import torch +from pydantic import BaseModel +from torch import nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import ( + BlockMask, + _mask_mod_signature, + flex_attention, +) +from xformers.ops import AttentionBias, fmha + +from bytelatent import probe + +flex_attention_comp = torch.compile(flex_attention) + + +class InitStdFactor(Enum): + DISABLED = "disabled" # Init std is divided by 1.0 + GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) + CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) + DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 + + +class BaseTransformerArgs(BaseModel): + dim: int = 512 + n_layers: int = 8 + head_dim: Optional[int] = None + n_heads: Optional[int] = None + n_kv_heads: Optional[int] = None + + ffn_dim_multiplier: Optional[float] = None + + multiple_of: int = 256 + + norm_eps: float = 1e-5 + + rope_theta: float = 10000.0 + + init_base_std: Optional[float] = None + init_std_factor: InitStdFactor = InitStdFactor.DISABLED + + max_seqlen: int = 1024 + + +def cross_entropy(pred, target, **kwargs): + return F.nll_loss( + F.log_softmax(pred.flatten(end_dim=-2).float(), -1), + target.flatten(end_dim=-1), + **kwargs, + ) + + +def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. 
+ """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + + cos, sin = freqs.cos(), freqs.sin() + + return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + seq_dim (int): Sequence dimension index. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= seq_dim < ndim + assert freqs_cis.shape == ( + x.shape[seq_dim], + x.shape[-3], + 2, + 2, + ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" + shape = [ + d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2]) + ] + [2, 2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + seq_dim: int, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + freqs_cis = reshape_for_broadcast( + freqs_cis, xq_, seq_dim + ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + xq_out = (xq_ * freqs_cis).sum(5).flatten(3) + xk_out = (xk_ * freqs_cis).sum(5).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def lengths_to_start_ids(lengths): + doc_start = lengths.cumsum(0) + doc_start = doc_start.roll(1) + doc_start[0] = 0 + return doc_start + + +def lengths_to_local_ids(lengths): + assert lengths.ndim == 1 + nb_seqs = lengths.size(0) + total_seqlen = lengths.sum() + # This gives the document id of each token + doc_id = torch.repeat_interleave(lengths) + # Compute document start for each document + doc_start = lengths_to_start_ids(lengths) + # Compute document start for each token + doc_start = doc_start[doc_id] + # Compute the position of each token within each document + tok_id = torch.arange(total_seqlen, device=lengths.device) - doc_start + + return doc_id, tok_id + + +def generate_doc_mask_mod( + mask_mod: _mask_mod_signature, + lengths: torch.Tensor, + kv_lengths: Optional[torch.Tensor] = None, +) -> _mask_mod_signature: + """Generates mask mods that apply to inputs to flex attention in the sequence stacked + format. + + Args: + mask_mod: The mask mod to apply to the documents + lengths: Lengths of each document + + Note: + What is the sequence stacked format? When assembling batches of inputs, we + take multiple sequences and stack them together to form 1 large sequence. We then + use masking to ensure that the attention scores are only applied to tokens within + the same document. 
+ + Example: + + - Square mask + doc_mask lengths + a a b b b c c 2 3 2 + a 1 0 0 0 0 0 0 + a 1 1 0 0 0 0 0 + b 0 0 1 0 0 0 0 + b 0 0 1 1 0 0 0 + b 0 0 1 1 1 0 0 + c 0 0 0 0 0 1 0 + c 0 0 0 0 0 1 1 + + """ + kv_lengths = kv_lengths if kv_lengths is not None else lengths + q_document_id, q_token_id = lengths_to_local_ids(lengths) + kv_document_id, kv_token_id = lengths_to_local_ids(kv_lengths) + q_max_idx = lengths.sum() - 1 + kv_max_idx = kv_lengths.sum() - 1 + + def doc_mask_mod(b, h, q_idx, kv_idx): + q_idx_cap = torch.minimum(q_max_idx, q_idx) + kv_idx_cap = torch.minimum(kv_max_idx, kv_idx) + valid_idx = (q_idx <= q_max_idx) & (kv_idx <= kv_max_idx) + same_doc = q_document_id[q_idx_cap] == kv_document_id[kv_idx_cap] + q_logical = q_token_id[q_idx_cap] + kv_logical = kv_token_id[kv_idx_cap] + inner_mask = mask_mod(b, h, q_logical, kv_logical) + return same_doc & inner_mask & valid_idx + + return doc_mask_mod + + +# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. +class RotaryEmbedding(torch.nn.Module): + """ + RotaryEmbedding Module + """ + + def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024): + super().__init__() + + self.theta = theta + self.head_dim = head_dim + self.max_seqlen = max_seqlen + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta), + persistent=False, + ) + + def reset_parameters(self): + self.freqs_cis[...] = precompute_freqs_cis( + dim=self.head_dim, end=self.max_seqlen, theta=self.theta + ) + + def forward( + self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None + ): + """ + Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions + Args: + seqlen (int): Contiguous sequence length + tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen + + Returns: + Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis + """ + test = (seqlen is not None) or (tok_idx is not None) + assert test, "Should provide atleast seqlen or tok_idx" + if tok_idx is not None: + return self.freqs_cis[tok_idx] + elif seqlen is not None: + return self.freqs_cis[0:seqlen] + + +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. 
+ + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + x = probe.log_stats(x, "resid") + output = self._norm(x.float()) + return (output * self.weight.float()).type_as(x) + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + rope_theta: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + self.rope_theta = rope_theta + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + # B S D + bsz, seq_len, dim = x.shape + xq = self.wq(x.view_as(x)) + xk = self.wk(x.view_as(x)) + xv = self.wv(x.view_as(x)) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) + + # This condition helps us be easily compatible + # with inference by adding a pluggable KVCache + if hasattr(self, "kv_cache"): + xk, xv = self.kv_cache.update(xk, xv, tok_idx) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + if attn_impl == "flex_attention": + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + elif attn_impl == "fmha": + assert mask is None or isinstance(mask, AttentionBias) + output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + # This uses B S H D instead of B H S D of pytorch + + elif attn_impl == "sdpa": + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + assert mask is None or isinstance(mask, (str, torch.Tensor)) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask if isinstance(mask, torch.Tensor) else None + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + else: + raise NotImplementedError( + f"Attention implementation {attn_impl} not supported" + ) + + output = self.wo(output.reshape(output_shape)) + + return output + + def reset_parameters(self, init_std=None, factor=1.0): + init_std = init_std or (self.dim ** (-0.5)) + + for w in [self.wq, self.wk, self.wv]: + nn.init.trunc_normal_( + w.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=init_std / factor, + a=-3 * init_std, + b=3 * init_std, + ) 
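+    # Shape sketch (illustrative sizes: dim=512, n_heads=8, n_kv_heads=2, head_dim=64,
+    # so heads_per_group == 4). For an input x of shape (B, S, 512):
+    #   wq(x) viewed  -> (B, S, 8, 64)
+    #   wk/wv viewed  -> (B, S, 2, 64)
+    #   repeat_kv(xk, 4, dim=2) -> (B, S, 8, 64), matching the query heads,
+    # before the selected backend (sdpa / flex_attention / fmha) computes attention.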
+ + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + mp_size: int = 1, + ): + super().__init__() + + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + assert hidden_dim % mp_size == 0 + + self.dim = dim + self.hidden_dim = hidden_dim + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # B S D + x1 = self.w1(x.view_as(x)) + x3 = self.w3(x.view_as(x)) + output = self.w2(F.silu(x1) * x3) + return output + + def reset_parameters(self, init_std=None, factor=1.0): + in_init_std = init_std or (self.dim ** (-0.5)) + out_init_std = init_std or (self.hidden_dim ** (-0.5)) + in_init_std = in_init_std + out_init_std = out_init_std / factor + for w in [self.w1, self.w3]: + nn.init.trunc_normal_( + w.weight, + mean=0.0, + std=in_init_std, + a=-3 * in_init_std, + b=3 * in_init_std, + ) + nn.init.trunc_normal_( + self.w2.weight, + mean=0.0, + std=out_init_std, + a=-3 * out_init_std, + b=3 * out_init_std, + ) + + +class TransformerBlock(nn.Module): + def __init__(self, args: BaseTransformerArgs): + super().__init__() + + assert (args.head_dim is not None) or ( + args.n_heads is not None + ), "Should specify at least head_dim or n_heads" + self.head_dim = args.head_dim or args.dim // args.n_heads + self.n_heads = args.n_heads or args.dim // args.head_dim + self.n_kv_heads = args.n_kv_heads or self.n_heads + + assert args.n_heads % self.n_kv_heads == 0 + assert args.dim % args.n_heads == 0 + + self.attention = Attention( + dim=args.dim, + head_dim=self.head_dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + rope_theta=args.rope_theta, + ) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + h = x + self.attention( + self.attention_norm(x), + freq_cis, + tok_idx=tok_idx, + mask=mask, + attn_impl=attn_impl, + ) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self, init_std=None, factor=1.0): + self.attention.reset_parameters(init_std, factor) + self.attention_norm.reset_parameters() + + self.feed_forward.reset_parameters(init_std, factor) + self.ffn_norm.reset_parameters() + + +class BaseTransformer(nn.Module): + def __init__(self, args: BaseTransformerArgs): + super().__init__() + self.dim = args.dim + self.init_base_std = args.init_base_std + self.init_std_factor = InitStdFactor(args.init_std_factor) + self.max_seqlen = args.max_seqlen + self.rope_embeddings = RotaryEmbedding( + theta=args.rope_theta, + head_dim=args.head_dim or args.dim // args.n_heads, + max_seqlen=args.max_seqlen, + ) + + self.layers = nn.ModuleList() + for _ in range(args.n_layers): + self.layers.append(TransformerBlock(args)) + + def forward( + self, + h, + tok_idx: Optional[torch.Tensor] = None, + mask: 
Optional[Union[BlockMask, AttentionBias, str]] = None, + attn_impl: str = "sdpa", + ): + + freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + return h + + def reset_parameters(self): + # Either use fixed base std or sqrt model dim + self.rope_embeddings.reset_parameters() + + def init_weights(self): + self.reset_parameters() + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(self.init_base_std, factor) diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py new file mode 100644 index 0000000..bcf591e --- /dev/null +++ b/bytelatent/checkpoint.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import json +import logging +import os +import re +from pathlib import Path +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import torch.nn as nn +import torch.optim.optimizer +from pydantic import BaseModel, ConfigDict +from torch.distributed._tensor import DeviceMesh +from torch.distributed.checkpoint.format_utils import dcp_to_torch_save +from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + get_state_dict, + set_state_dict, +) + +from bytelatent.distributed import get_is_master + +logger = logging.getLogger("CHECKPOINT") + +FOLDER_NAME = "{:010d}" +RE_FOLDER = r"\d{10}" + +RE_CKPT = r"__\d_\d\.distcp" + +CONSOLIDATE_FOLDER = "consolidated" +CONSOLIDATE_NAME = "consolidated.pth" + +CONFIG_NAME = "params.json" +TRAIN_STATE_NAME = "train_state_{:05d}.json" +RE_DIGITS = re.compile(r"\d+") + + +class SaveEvery(BaseModel): + model_config = ConfigDict(extra="forbid") + every: int = 1000 + keep: int = 0 + + +class CheckpointArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dump: SaveEvery = SaveEvery() + eval: SaveEvery = SaveEvery() + path: str | None = None + init_ckpt_path: str | None = None + continue_training_from_init: bool = False + + +def _get_key_step(name: str): + return int(re.findall(RE_DIGITS, name)[-1]) + + +def consolidate_checkpoints(ckpt_dir: str): + """ + Consolidates all FSDP checkpoints in a directory to a single file + Consolidate checkpoint is saved in a subdirectory of ckpt_dir + + Parameters: + ckpt_dir: str - path to the directory containing the checkpoints + + Returns the path to the consolidated checkpoint + """ + consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER + if not (consolidate_path / CONSOLIDATE_NAME).exists(): + consolidate_path.mkdir(exist_ok=True) + logger.info(f"Consolidating to: {str(consolidate_path)}") + dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) + (consolidate_path / CONFIG_NAME).write_text( + (Path(ckpt_dir) / CONFIG_NAME).read_text() + ) + logger.info("Consolidated !") + return consolidate_path + + +def load_from_checkpoint( + ckpt_dir: str, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + model_key: str = "model", + optim_key: str = "optim", +): + if not (Path(ckpt_dir) / ".metadata").exists(): + raise ValueError( + f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" + ) + 
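+    # Build a state-dict template from the (possibly sharded) model/optimizer;
+    # dcp.load below fills these tensors in place from the checkpoint files.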
+ state_dict = {} + if optimizer is not None: + state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer) + else: + state_dict[model_key] = get_model_state_dict(model) + if model_key == "": # If only loading a model directly, the key should be empty + state_dict = state_dict.pop(model_key) + + dcp.load(state_dict, checkpoint_id=ckpt_dir) + + +class CheckpointManager: + def __init__(self, args: CheckpointArgs): + self.path = args.path + self.dump_every = args.dump + self.eval_every = args.eval + self.init_ckpt_path = args.init_ckpt_path + self.continue_training_from_init = args.continue_training_from_init + + assert os.path.exists( + self.path + ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" + + self.existing_saves = self.get_existing_saves() + + def get_existing_saves(self) -> List[Path]: + folders = [ + p + for p in Path(self.path).iterdir() + if p.is_dir() and re.match(RE_FOLDER, p.name) + ] + folders.sort(key=lambda p: _get_key_step(p.name)) + return folders + + def clean_up(self): + logger.info("Cleaning up checkpoints...") + dump_folders = [] + eval_folders = [] + other_folders = [] + for p in self.existing_saves: + is_dump = _get_key_step(p.name) % self.dump_every.every == 0 + is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + if is_dump: + dump_folders.append(p) + if is_eval: + eval_folders.append(p) + if not (is_dump or is_eval): + other_folders.append(p) + + logger.info(f"Dump folders: {dump_folders}") + logger.info(f"Eval folders: {eval_folders}") + logger.info(f"Other folders: {other_folders}") + + if self.dump_every.keep > 0: + dump_folders = dump_folders[-self.dump_every.keep :] + if self.eval_every.keep > 0: + eval_folders = eval_folders[-self.eval_every.keep :] + + folder_to_keep = set(other_folders + dump_folders + eval_folders) + folder_to_remove = set(self.existing_saves) - folder_to_keep + + logger.info(f"Removing folders: {folder_to_remove}") + + if dist.get_rank() == 0: + for folder in folder_to_remove: + for file in folder.iterdir(): + if file.is_file(): + file.unlink() + elif file.is_dir(): + assert file.name in [CONSOLIDATE_FOLDER] + for f in file.iterdir(): + f.unlink() + file.rmdir() + folder.rmdir() + + dist.barrier() + + self.existing_saves = list(folder_to_keep) + self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + + def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + path = None + for p in reversed(self.existing_saves): + if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + path = p + break + return path + + def _create_folder(self, base_path: Path, folder_name: str) -> Path: + folder = base_path / folder_name + if get_is_master(): + folder.mkdir(parents=False, exist_ok=True) + if dist.is_initialized(): + dist.barrier() + return folder + + def _get_dp_tp_mesh( + self, device_mesh: Optional[DeviceMesh] = None + ) -> Tuple[int, int]: + dp_rank = 0 + tp_rank = 0 + if device_mesh is not None: + if "dp_replicate" in device_mesh.mesh_dim_names: + dp_rank = device_mesh.get_local_rank("dp_replicate") + if "dp_shard" in device_mesh.mesh_dim_names: + dp_rank = dp_rank * device_mesh[ + "dp_replicate" + ].size() + device_mesh.get_local_rank("dp_shard") + if "tp" in device_mesh.mesh_dim_names: + tp_rank = device_mesh.get_local_rank("tp") + return dp_rank, tp_rank + + @torch.no_grad() + def get_state_dict( + self, + model, + optimizer, + ): + model_sd, optim_sd = get_state_dict(model, optimizer) + return {"model": model_sd, "optim": 
optim_sd} + + def save( + self, + model, + optimizer, + train_state, + config, + device_mesh: Optional[DeviceMesh] = None, + ) -> bool: + + # When creating directory check if only rank0 or is there other solution + path = Path(self.path) + curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) + logger.info(f"Saving to: {str(curr_save_dir)}") + + if dist.is_initialized(): + dist.barrier() + + logger.info("Saving...") + state_dict = self.get_state_dict(model, optimizer) + dcp.save(state_dict, checkpoint_id=curr_save_dir) + logger.info("State dict saved!") + + if dist.is_initialized(): + dist.barrier() + + if get_is_master(): + config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + + # Add json dump here + dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) + if tp_rank == 0: + train_state_name = TRAIN_STATE_NAME.format(dp_rank) + logger.info( + f"Saving train state to: {str(curr_save_dir / train_state_name)}" + ) + with open(curr_save_dir / train_state_name, "w") as f: + json.dump(train_state.state_dict(), f) + logger.info("Train state saved !") + + self.existing_saves.append(curr_save_dir) + + self.clean_up() + + if dist.is_initialized(): + dist.barrier() + return True + + @torch.no_grad() + def load( + self, + model: nn.Module, + optimizer, + train_state, + device_mesh: DeviceMesh, + path: Optional[Path] = None, + ): + dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) + # Loading tries to load the provided path, if not available the last saved step and finally from the init path + path = path or self.get_last_step_path(dp_rank=dp_rank) + # If none of those are available don't do anything + if path is None: + # If no checkpoints exist do nothing + return + + # Only load train state if it's provided, the files exist and we're not loading from init path + train_state_name = TRAIN_STATE_NAME.format(dp_rank) + logger.info("Reloading train state") + with open(path / train_state_name, "r") as f: + train_state_dict = json.load(f) + train_state.load_state_dict(train_state_dict) + logger.info("Train state reloaded") + + logger.info(f"Loading from: {str(path)}") + state_dict = self.get_state_dict( + model=model, + optimizer=optimizer, + ) + dcp.load(state_dict, checkpoint_id=path) + logger.info("State dict loaded.") + + logger.info("Reloading model and optim") + + set_state_dict( + model, + optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optim"], + ) + logger.info("Model and optim reloaded") + + @classmethod + def instantiate_and_make_dir(cls, args: CheckpointArgs): + if get_is_master(): + os.makedirs(args.path, exist_ok=True) + dist.barrier() + + return cls(args) diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml new file mode 100644 index 0000000..5f6debb --- /dev/null +++ b/bytelatent/configs/debug.yaml @@ -0,0 +1,110 @@ +# Template config, need to change dump_dir, data.root_dir and tokenizer.path +# Evals can be activated by uncommenting its config +# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest + +dump_dir: /tmp/ +name: "debug" +steps: 100_000 +probe_freq: null +seed: 777 +optim: + lr: 4e-04 + warmup: 500 + lr_min_ratio: 0.1 + clip: 10.0 + +distributed: + fsdp_type: full_shard + compile: true + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +model: + n_heads: 8 + dim: 512 + vocab_size: 260 + dim_token: 256 + patch_size: 6 + tokenization_mode: "bytes" + patching_mode: "space" + 
tie_local_encoder_decoder_logits: false + data_loader_patching: true + max_encoder_seq_length: 12288 + pad_to_max_length: true + patching_threshold: 3.1439168453216553 + encoder_hash_byte_group_size: [4] + encoder_hash_byte_group_vocab: 50002 + encoder_hash_byte_group_nb_functions: 3 + encoder_enable_byte_ngrams: false + cross_attn_encoder: true # assuming cross_attention is true + cross_attn_decoder: true # assuming cross_attention is true + cross_attn_window_encoder: 512 + cross_attn_window_decoder: 512 + dim_local_encoder: 256 + dim_local_decoder: 256 + cross_attn_k: 8 + cross_attn_nheads: 4 + cross_attn_all_layers_decoder: true + cross_attn_all_layers_encoder: true + cross_attn_use_flex_attention: true + cross_attn_init_by_pooling: true + log_patch_lengths: true + non_linearity: "swiglu" + use_rope: true + recompute_fc1_out: false + recompute_fc3_out: false + recompute_attn: false + custom_bwd: false + layer_ckpt: "none" + efficient_attn: "sdpa" + patch_only_encoder: false + patch_only_decoder: false + use_local_encoder_transformer: true + init_use_gaussian: true + init_use_depth: "current" + attn_bias_type: "block_causal" + alpha_depth: "disabled" + max_length: 256 + local_attention_window_len: 512 + max_seqlen: 12288 + downsampling_by_pooling: "max" + +data: + root_dir: ??? + sources: + dclm_baseline_1.0: 1.0 + batch_size: 2 + prefetch_size: 64 + seq_len: 4096 + load_async: true + preprocess_dir: ??? + tokenizer_args: + name: blt + init_kwargs: + bpe_tokenizer_path: ??? + +profiling: + run: false + +checkpoint: + dump: + every: 500 + keep: 3 + eval: + every: 1000 + keep: -1 + +logging: + freq: 10 + +eval_on_gpus: 8 +eval: + dataset_dir: /checkpoint/amaia/codegen/datasets/eval + tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu + generator: + max_tokens: 65536 + dtype: bf16 + + mp_size: 1 diff --git a/bytelatent/constants.py b/bytelatent/constants.py new file mode 100644 index 0000000..341c7ff --- /dev/null +++ b/bytelatent/constants.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import os +from pathlib import Path + +BLT_DATA = Path(os.environ.get("BLT_DATA", "data")) diff --git a/bytelatent/data/__init__.py b/bytelatent/data/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/data/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py new file mode 100644 index 0000000..7e142e4 --- /dev/null +++ b/bytelatent/data/data_types.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+import json +from dataclasses import dataclass +from typing import Any, Iterator + +import numpy as np +from pydantic import BaseModel, ConfigDict + + +class BltExample(BaseModel): + model_config = ConfigDict(extra="forbid") + sample_id: str + text: str + tokens: list[int] | None + entropies: list[float] | None + patch_lengths: list[int] | None + mask: list[bool] | None + + +class MultiChoiceState(BaseModel): + model_config = ConfigDict(extra="forbid") + root_dir: str + sources: dict[str, float] + source_to_state: dict[str, Any] + rng_state: dict[str, Any] + + +class PrefetchState(BaseModel): + model_config = ConfigDict(extra="forbid") + seq_idx: int + rng_state: dict[str, Any] + prefetch_size: int + batch_size: int + + +class BltPackTokensState(BaseModel): + model_config = ConfigDict(extra="forbid") + start_token: int + output_seq_len: int + n_views: int = 2 + + +class DataLoaderState(BaseModel): + model_config = ConfigDict(extra="forbid") + multi_choice_state: MultiChoiceState + pack_tokens_state: BltPackTokensState + prefetch_state: PrefetchState + + +BltIterator = Iterator[tuple[BltExample, DataLoaderState]] + + +class BltSequence(BaseModel): + tokens: list[int] + mask: list[bool] + patch_lengths: list[int] + + +@dataclass +class Batch: + x: np.ndarray + y: np.ndarray + mask: np.ndarray | None = None + patch_lengths: np.ndarray | None = None + ngram_ids: np.ndarray | None = None + is_final: bool = False + + def to_python_dict(self) -> dict: + x = self.x.tolist() + y = self.y.tolist() + if self.mask is None: + mask = None + else: + mask = self.mask.tolist() + if self.patch_lengths is None: + patch_lengths = None + else: + patch_lengths = self.patch_lengths.tolist() + if self.ngram_ids is None: + ngram_ids = None + else: + ngram_ids = self.ngram_ids.tolist() + return { + "x": x, + "y": y, + "mask": mask, + "patch_lengths": patch_lengths, + "ngram_ids": ngram_ids, + "is_final": self.is_final, + } + + @classmethod + def from_python_dict(cls, data: dict) -> "Batch": + x = np.array(data["x"]) + y = np.array(data["y"]) + if data["mask"] is None: + mask = None + else: + mask = np.array(data["mask"]) + if data["patch_lengths"] is None: + patch_lengths = None + else: + patch_lengths = np.array(data["patch_lengths"]) + if data["ngram_ids"] is None: + ngram_ids = None + else: + ngram_ids = np.array(data["ngram_ids"]) + return Batch( + x=x, + y=y, + mask=mask, + patch_lengths=patch_lengths, + ngram_ids=ngram_ids, + is_final=data["is_final"], + ) diff --git a/bytelatent/data/iterators/__init__.py b/bytelatent/data/iterators/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/data/iterators/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py new file mode 100644 index 0000000..7fb442b --- /dev/null +++ b/bytelatent/data/iterators/abstract_iterator.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
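+# Contract sketch (illustrative): a StatefulIterator yields items via create_iter()
+# and snapshots progress via get_state(); the returned IteratorState rebuilds an
+# equivalent iterator with build(), e.g.
+#
+#   it = iterator.create_iter()
+#   first = next(it)
+#   state = iterator.get_state()             # serializable snapshot
+#   resumed = state.build().create_iter()    # resumes from the snapshot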
+import abc +from typing import Any, Generator, Generic, TypeVar + +T = TypeVar("T") +C = TypeVar("C") + + +class StatefulIterator(Generic[T, C], abc.ABC): + + @abc.abstractmethod + def get_state(self) -> C: + pass + + @abc.abstractmethod + def create_iter(self) -> Generator[T, Any, None]: + pass + + +class IteratorState(Generic[C]): + @abc.abstractmethod + def build(self) -> StatefulIterator[T, C]: + pass diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py new file mode 100644 index 0000000..df5f023 --- /dev/null +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import re +from logging import getLogger +from pathlib import Path +from typing import Any, Generator + +import pyarrow as pa + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore +from pydantic import BaseModel, ConfigDict + +from bytelatent import ByteLatentError +from bytelatent.data.data_types import BltExample +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator + +logger = getLogger(__name__) + + +class ArrowFileIteratorState(BaseModel, IteratorState): + model_config = ConfigDict(extra="forbid") + file_path: str | None + row_num: int + num_workers: int + worker_id: int + preprocess_dir: str | None + dataset_files: list[str] | None + entropy_model_name: str | None + arrow_batch_size: int = 100 + + def build(self) -> "ArrowFileIterator": + arrow_file = ArrowFileIterator( + file_path=self.file_path, + worker_id=self.worker_id, + num_workers=self.num_workers, + preprocess_dir=self.preprocess_dir, + entropy_model_name=self.entropy_model_name, + arrow_batch_size=self.arrow_batch_size, + dataset_files=self.dataset_files, + ) + if self.row_num != 0: + arrow_file._set_row_num(self.row_num) + return arrow_file + + +def shard_sort_key(file: str | Path): + match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) + shard_number = int(match.group(1)) + return shard_number + + +class ArrowFileIterator(StatefulIterator): + def __init__( + self, + *, + file_path: str | None, + worker_id: int, + num_workers: int, + preprocess_dir: str | None, + entropy_model_name: str | None, + arrow_batch_size: int, + dataset_files: list[str] | None = None, + ): + assert 0 <= worker_id < num_workers, (worker_id, num_workers) + if file_path is None and dataset_files is None: + raise ByteLatentError("file_path and dataset_files cannot both be None") + self.row_num = 0 + self.iter_id = 0 + self.batch_iterator = None + self.batch_to_consume = None + self.dataset = None + self.file_path = file_path + self.worker_id = worker_id + self.num_workers = num_workers + self.preprocess_dir = preprocess_dir + self.entropy_model_name = entropy_model_name + self.arrow_batch_size = arrow_batch_size + if dataset_files is None: + # Prepare arrow shards + jsonl_file = Path(file_path) + parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + assert parts is not None + dataset = parts.group(1) + data_dir = Path(preprocess_dir) / dataset / entropy_model_name + shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + for s in shard_files: + if not (data_dir / f"{s.name}.complete").exists(): + raise ValueError(f"Missing .complete for input file: {s}") + + shard_files = sorted(shard_files, key=shard_sort_key) + if len(shard_files) == 0: + raise ByteLatentError( + f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and 
entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" + ) + self.dataset_files = [str(f) for f in shard_files] + else: + self.preprocess_dir = None + self.dataset_files = dataset_files + + def get_state(self) -> ArrowFileIteratorState: + return ArrowFileIteratorState( + file_path=self.file_path, + row_num=self.row_num, + worker_id=self.worker_id, + num_workers=self.num_workers, + preprocess_dir=self.preprocess_dir, + entropy_model_name=self.entropy_model_name, + arrow_batch_size=self.arrow_batch_size, + dataset_files=self.dataset_files, + ) + + def create_iter( + self, + ) -> Generator[BltExample, Any, None]: + if self.dataset is None: + self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + self.batch_iterator = self.dataset.to_batches( + batch_size=self.arrow_batch_size + ) + self.iter_id += 1 + if self.batch_to_consume is not None: + batch_columns: dict[str, list] = self.batch_to_consume + self.batch_to_consume = None + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + for i in range(len(sample_ids)): + out = BltExample( + sample_id=sample_ids[i], + entropies=entropies[i], + text=texts[i], + tokens=None, + mask=None, + patch_lengths=None, + ) + self.row_num += 1 + if (self.row_num - 1) % self.num_workers == self.worker_id: + yield out + + for batch in self.batch_iterator: + batch_columns = batch.to_pydict() + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + for i in range(len(sample_ids)): + out = BltExample( + sample_id=sample_ids[i], + entropies=entropies[i], + text=texts[i], + tokens=None, + mask=None, + patch_lengths=None, + ) + self.row_num += 1 + if (self.row_num - 1) % self.num_workers == self.worker_id: + yield out + + def _set_row_num(self, target_row_num: int): + logger.info( + f"Setting arrow position to {target_row_num} for {self.dataset_files}" + ) + if target_row_num is None or target_row_num == 0: + self.row_num = 0 + self.dataset = None + self.batch_iterator = None + self.batch_to_consume = None + else: + self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + self.batch_iterator = self.dataset.to_batches( + batch_size=self.arrow_batch_size + ) + curr_remaining = target_row_num + for batch in self.batch_iterator: + if len(batch) > curr_remaining: + batch_columns: dict[str, list] = batch.to_pydict() + batch_columns["sample_id"] = batch_columns["sample_id"][ + curr_remaining: + ] + batch_columns["entropies"] = batch_columns["entropies"][ + curr_remaining: + ] + batch_columns["text"] = batch_columns["text"][curr_remaining:] + self.batch_to_consume = batch_columns + break + elif len(batch) == curr_remaining: + # We are exactly at the end of the batch, + # so the next batch is the right spot + break + else: + curr_remaining -= len(batch) + self.row_num = target_row_num + logger.info( + f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" + ) + + +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN +): + dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of 
number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks diff --git a/bytelatent/data/iterators/looping_iterator.py b/bytelatent/data/iterators/looping_iterator.py new file mode 100644 index 0000000..2eff38c --- /dev/null +++ b/bytelatent/data/iterators/looping_iterator.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from pydantic import BaseModel + +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.arrow_iterator import ( + ArrowFileIterator, + ArrowFileIteratorState, +) + + +class LoopingIteratorState(BaseModel, IteratorState): + file_iterator_state: ArrowFileIteratorState + epoch: int + + def build(self) -> "LoopingIterator": + return LoopingIterator( + file_iterator=self.file_iterator_state.build(), + epoch=self.epoch, + ) + + +class LoopingIterator(StatefulIterator): + def __init__(self, file_iterator: ArrowFileIterator, epoch: int = -1): + self.file_iterator = file_iterator + self.epoch = epoch + + def get_state(self): + return LoopingIteratorState( + file_iterator_state=self.file_iterator.get_state(), epoch=self.epoch + ) + + def create_iter(self): + while True: + self.epoch += 1 + iterator = self.file_iterator.create_iter() + yield from iterator diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py new file mode 100644 index 0000000..f17ca6e --- /dev/null +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import json +import logging +import multiprocessing as mp +from multiprocessing.synchronize import Event as EventClass +from queue import Empty, Full + +import numpy as np +from pydantic import BaseModel, ConfigDict + +from bytelatent.data.data_types import Batch +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.packing_iterator import PackingIteratorState + +logger = logging.getLogger() + + +class MultiprocessIteratorState(BaseModel, IteratorState): + model_config = ConfigDict(extra="forbid") + base_iterator_state: PackingIteratorState + n_batches_to_prefetch: int + serialized_prefetch_buffer: str + + def build(self): + base_iterator = self.base_iterator_state.build() + data = json.loads(self.serialized_prefetch_buffer) + prefetch_buffer = [Batch.from_python_dict(item) for item in data] + return MultiprocessIterator( + base_iterator, + n_batches_to_prefetch=self.n_batches_to_prefetch, + prefetch_buffer=prefetch_buffer, + ) + + +def start_work_from_state( + batch_queue: mp.Queue, + state_queue: mp.Queue, + stop_event: EventClass, + state_dumped_event: EventClass, + state: IteratorState, +): + logging.info("Worker thread: Starting base_iterator work") + stateful_iterator = state.build() + iterator = stateful_iterator.create_iter() + for item in iterator: + while not stop_event.is_set(): + try: + # Attempt to put on queue or timeout to try again (maybe main thread is busy) + batch_queue.put(item, timeout=0.1) + # On success, stop trying + break + except Full: + pass + if stop_event.is_set(): + # Signal the end of output, this ensures that even if the queue takes a while to + # buffer, that the main thread receives everything (and tosses this fake batch) + logging.info( + "Worker thread: Stop event detected, outputting is_final=True batch" + ) + batch_queue.put( + Batch( + x=np.zeros((1, 1)), + y=np.zeros((1, 1)), + is_final=True, 
+                    mask=None,
+                    patch_lengths=None,
+                    ngram_ids=None,
+                )
+            )
+            break
+
+    try:
+        logging.info("Worker thread: outputting state")
+        state_queue.put(stateful_iterator.get_state(), timeout=1)
+        logging.info("Worker thread: state dump complete")
+        state_dumped_event.set()
+        logging.info("Worker thread: set state_dumped_event")
+    except Full:
+        raise ValueError(
+            "Attempted to dump state into the state queue, but it was full"
+        )
+
+
+class MultiprocessIterator(StatefulIterator):
+    """
+    Design sketch of the multiprocess iterator:
+
+    Given the base_iterator, the only thing we do with it is call get_state()
+    so that we can pass that through to the background worker process.
+
+    The background process receives this state, rebuilds the iterator, then starts yielding from it.
+
+    However, in order to implement MultiprocessIterator.get_state(), we need to be able to accurately get
+    (1) the state of the iterator in the worker process
+    (2) the currently buffered items in the Queue
+
+    To do this, we use:
+    - batch_queue: the prefetch buffer the worker yields to and the main loop yields from.
+    - state_queue: this size-1 queue is how the worker sends the iterator state once it has halted iterating.
+        It must hold the state in addition to the last batch, if the queue was full at the time the stop event was sent.
+    - stop_iterating_event: once this is issued from the main loop, the worker stops iterating and enters cleanup.
+        During cleanup, the worker sends the state of the current iterator to the main loop,
+        plus possibly the last batch if the batch_queue was full at the time.
+    - state_dumped_event: when the main loop issues the stop_iterating_event, it waits for the state_dumped_event before
+        attempting to read from the state_queue. It must do this since the worker may take some time to create and send the state.
+        Once received, the main loop can safely store the queued batches (plus maybe the last batch) as the prefetch buffer,
+        get the worker iterator's state, and terminate the background process + delete associated objects.
+
+    At this point, calling create_iter() again will bootstrap everything from the stored state; the old iterator will raise an error
+    since it will not iterate anymore (so the caller must call create_iter() again to get a Python iterator).
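+
+    Usage sketch (here `packing_iterator` stands for any StatefulIterator, e.g. a PackingIterator):
+
+        mp_it = MultiprocessIterator(packing_iterator, n_batches_to_prefetch=64)
+        batches = mp_it.create_iter()
+        batch = next(batches)
+        state = mp_it.get_state()   # halts the worker and captures buffered batches
+        mp_it = state.build()       # later: rebuild, then call create_iter() again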
+ + """ + + def __init__( + self, + base_iterator: StatefulIterator, + *, + n_batches_to_prefetch: int, + prefetch_buffer: list | None = None + ): + self.base_iterator = base_iterator + self.n_batches_to_prefetch = n_batches_to_prefetch + if prefetch_buffer is None: + prefetch_buffer = [] + self.prefetch_buffer = prefetch_buffer + self.batch_queue = None + self.state_queue = None + self.producer = None + self.stop_iterating_event = None + self.state_dumped_event = None + + def get_state(self) -> MultiprocessIteratorState: + """ + This is slightly unusual in effectively destroying the current iterator, its necessary + to halt the background process and allow it to write the state to the main loop + in order to not lose data + """ + if self.producer is None: + serialized_prefetch_buffer = json.dumps( + [b.to_python_dict() for b in self.prefetch_buffer] + ) + return MultiprocessIteratorState( + base_iterator_state=self.base_iterator.get_state(), + n_batches_to_prefetch=self.n_batches_to_prefetch, + serialized_prefetch_buffer=serialized_prefetch_buffer, + ) + else: + logging.info("Main thread: Sending stop iteration event") + self.stop_iterating_event.set() + logging.info("Main thread: Waiting for state_dumped event") + self.state_dumped_event.wait() + self.prefetch_buffer = [] + final_batch_received = False + while True: + try: + batch = self.batch_queue.get(timeout=1) + if batch.is_final: + final_batch_received = True + break + self.prefetch_buffer.append(batch) + except Empty: + logging.warning("Main thread: batch_queue is abnormally empty") + assert final_batch_received + + try: + base_iterator_state = self.state_queue.get(timeout=1) + assert isinstance(base_iterator_state, IteratorState) + except Empty: + raise ValueError( + "Attempted to get the state, but it was unexpectantly missing" + ) + + self.base_iterator = base_iterator_state.build() + self.producer.close() + self.producer = None + self.batch_queue = None + self.state_queue = None + self.stop_iterating_event = None + self.state_dumped_event = None + + return MultiprocessIteratorState( + base_iterator_state=self.base_iterator.get_state(), + n_batches_to_prefetch=self.n_batches_to_prefetch, + serialized_prefetch_buffer=json.dumps( + [b.to_python_dict() for b in self.prefetch_buffer] + ), + ) + + def create_iter(self): + logging.info("Main thread: Creating MP iterator") + # First yield from the stored prefetch buffer. + if self.prefetch_buffer is not None: + while len(self.prefetch_buffer) > 0: + item = self.prefetch_buffer.pop(0) + yield item + self.prefetch_buffer = None + + assert ( + self.producer is None + ), "Cannot create two parallel iterators at once, call get_state() then remake to have two." 
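# --- Illustrative sketch (not part of this commit) --------------------------
# The get_state() handshake above can be hard to follow in context. This is the
# same stop-event / state-queue pattern reduced to bare multiprocessing
# primitives (the real code uses a Manager queue); every name below
# (worker, batch_q, ...) is hypothetical.
import multiprocessing
from queue import Empty, Full

def worker(batch_q, state_q, stop, state_dumped):
    state = 0
    while not stop.is_set():
        try:
            batch_q.put(state, timeout=0.1)  # produce the next "batch"
            state += 1
        except Full:
            pass  # main loop is busy; retry and re-check the stop event
    state_q.put(state)   # dump the resume point once iteration has halted
    state_dumped.set()   # signal that the state is safe to read

if __name__ == "__main__":
    ctx = multiprocessing.get_context("forkserver")
    batch_q, state_q = ctx.Queue(maxsize=4), ctx.Queue(maxsize=1)
    stop, state_dumped = ctx.Event(), ctx.Event()
    p = ctx.Process(target=worker, args=(batch_q, state_q, stop, state_dumped))
    p.start()
    consumed = [batch_q.get() for _ in range(3)]  # normal iteration
    stop.set()                                    # begin the "get_state" handshake
    state_dumped.wait()                           # worker has stopped and dumped its state
    prefetch_buffer = []
    while True:                                   # drain whatever was left in flight
        try:
            prefetch_buffer.append(batch_q.get(timeout=0.1))
        except Empty:
            break
    resume_state = state_q.get()                  # analogous to the base iterator state
    p.join()
# -----------------------------------------------------------------------------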
+ + # using mp context manager avoids excessive CPU loading + ctx = mp.get_context("forkserver") + self.batch_queue = ctx.Manager().Queue(maxsize=self.n_batches_to_prefetch) + + # We should only ever one state, which is output at the detection of a stop event + self.state_queue = ctx.Manager().Queue(maxsize=1) + + self.stop_iterating_event = ctx.Event() + self.state_dumped_event = ctx.Event() + + self.producer = mp.Process( + name="blt_data_loader", + target=start_work_from_state, + args=( + self.batch_queue, + self.state_queue, + self.stop_iterating_event, + self.state_dumped_event, + self.base_iterator.get_state(), + ), + ) + logger.info("Async dataloader started") + self.producer.start() + + while True: + if self.producer.exitcode is not None: + raise RuntimeError( + "Data loader quit unexpectedly, real error has been raised previously" + ) + try: + batch = self.batch_queue.get(timeout=0.1) + assert isinstance(batch, Batch) + assert ( + not batch.is_final + ), "is_final should only be used during get_state() being called" + yield batch + except Empty: + pass + if self.producer is None: + raise ValueError( + "Attempted to call this iterator after calling get_state(). You must call create_iter() to make a new iterator instead." + ) diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py new file mode 100644 index 0000000..361fc03 --- /dev/null +++ b/bytelatent/data/iterators/packing_iterator.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import Any + +import numpy as np +from pydantic import BaseModel, ConfigDict + +from bytelatent.data.data_types import Batch, BltSequence +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState + + +class PackingArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + batch_size: int + seq_len: int + pad_id: int + max_length: int | None + pad_to_max_length: bool + enable_byte_ngrams: bool + + +class PackingIteratorState(BaseModel, IteratorState): + model_config = ConfigDict(extra="forbid") + sequence_iterator_state: SamplingIteratorState + packing_args: PackingArgs + + def build(self) -> "PackingIterator": + return PackingIterator( + sequence_iterator=self.sequence_iterator_state.build(), + packing_args=self.packing_args, + ) + + +def _merge_patch_seq_masks(bs, slen: int, mask_seqs: list[list[bool]]): + assert len(mask_seqs) == bs + lens = [len(m) for m in mask_seqs] + if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens): + return None + assert slen == max(lens) - 1 + mask = np.zeros((bs, slen), dtype=bool) + for i, m in enumerate(mask_seqs): + if m is None: + print( + "Did not implement None mask, the mask should be True for all toks, so we need to pass that to this function." + ) + raise NotImplementedError + mask[i][: len(mask_seqs[i]) - 1] = mask_seqs[i][1:] + return mask + + +def truncate_batch( + batch: Batch, + max_length: int, + pad_id: int, + pad_to_max_length: bool = False, + *, + enable_byte_ngrams: bool, +): + """ + Truncate the x to a given size, making sure we remove the corresponding patch sizes in patch_lenghts + and fixing the batch.mask. 
+ + batch.patch_lengths has unchanged shape + x,y, and mask may reduce in size + """ + if batch.patch_lengths is None: + return batch + + seq_lengths = batch.patch_lengths.sum(axis=1) + max_length_adj = max_length + 1 + if np.any(seq_lengths > max_length_adj): + for i in range(batch.x.shape[0]): + if seq_lengths[i] > max_length_adj: + # Find id of patch that tips over max_length + 1 + count, j = 0, 0 + while count + batch.patch_lengths[i, j] <= max_length_adj: + count += batch.patch_lengths[i, j] + j += 1 + # Edit the batch + assert j < batch.patch_lengths.shape[1] + batch.x[i, max_length:] = pad_id + batch.y[i, max_length:] = pad_id + if batch.mask is not None: + batch.mask[i, max_length:] = False + batch.patch_lengths[i, j:] = 0 + batch.patch_lengths[i, j] = max_length_adj - count + + # Truncate if necessary. + if max_length < batch.x.shape[1]: + batch.x = batch.x[:, :max_length] + batch.y = batch.y[:, :max_length] + if batch.mask is not None: + batch.mask = batch.mask[:, :max_length] + + # Right pad to max_length if necessary + elif pad_to_max_length: + if batch.x.shape[1] < max_length: + # NOTE: this has to be done on an actual patch. + non_zero_indices = (batch.patch_lengths != 0).sum(axis=1) - 1 + non_zero_indices = np.maximum(0, non_zero_indices) + batch.patch_lengths[range(len(batch.patch_lengths)), non_zero_indices] += ( + max_length - batch.x.shape[1] + ) + # TODO: We could get rid of many of these complications by moving this funciton directly in the dataloader. + x = np.full((batch.x.shape[0], max_length), pad_id, dtype=batch.x.dtype) + x[:, : batch.x.shape[1]] = batch.x + batch.x = x + if batch.y.shape[1] < max_length: + y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype) + y[:, : batch.y.shape[1]] = batch.y + batch.y = y + if batch.mask is not None and batch.mask.shape[1] < max_length: + mask = np.full( + (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype + ) + mask[:, : batch.mask.shape[1]] = batch.mask + batch.mask = mask + + assert batch.x.shape[1] <= max_length + assert batch.y.shape[1] <= max_length + assert batch.mask is None or batch.mask.shape[1] <= max_length + assert np.all(max_length_adj - batch.patch_lengths.sum(axis=1) == 0) + if pad_to_max_length: + assert batch.x.shape[1] == max_length + assert batch.y.shape[1] == max_length + assert batch.mask is None or batch.mask.shape[1] == max_length + if enable_byte_ngrams: + raise NotImplementedError() + # (num_ngram, batch_size, seq_len) + ngram_ids = np.array(tokenizer.encode_token_ngrams(batch.x)) + assert ngram_ids.shape[2] == batch.x.shape[1] + else: + ngram_ids = None + batch.ngram_ids = ngram_ids + + +class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): + def __init__( + self, + sequence_iterator: StatefulIterator[BltSequence, Any], + *, + packing_args: PackingArgs, + ): + self.sequence_iterator = sequence_iterator + self.packing_args = packing_args + + def get_state(self): + return PackingIteratorState( + sequence_iterator_state=self.sequence_iterator.get_state(), + packing_args=self.packing_args, + ) + + def create_iter(self): + sequence_iter = self.sequence_iterator.create_iter() + batch_size = self.packing_args.batch_size + pad_id = self.packing_args.pad_id + seq_len = self.packing_args.seq_len + pad_to_max_length = self.packing_args.pad_to_max_length + enable_byte_ngrams = self.packing_args.enable_byte_ngrams + max_length = self.packing_args.max_length + while True: + tokens: list[list[int]] = [] + masks: list[list[bool]] = [] + patch_lengths: 
list[list[int]] = [] + + for _ in range(self.packing_args.batch_size): + sequence = next(sequence_iter) + _tokens = sequence.tokens + _mask = sequence.mask + _patch_lengths = sequence.patch_lengths + assert len(sequence.patch_lengths) == self.packing_args.seq_len + last_patch_length = 0 + if _patch_lengths[0] > 1: + last_patch_length = _patch_lengths[-1] + _patch_lengths[0] -= 1 + _patch_lengths = [1] + _patch_lengths[:-1] + tokens.append(_tokens[: len(_tokens) - last_patch_length]) + masks.append(_mask[: len(_mask) - last_patch_length]) + patch_lengths.append(_patch_lengths) + + x_patch_lengths = np.array(patch_lengths) + # pad batch to same length + tok_seq_len = max([len(toks) for toks in tokens]) - 1 + x = np.full((batch_size, tok_seq_len), fill_value=pad_id) + y = np.full((batch_size, tok_seq_len), fill_value=pad_id) + + for i, tok_seq in enumerate(tokens): + x[i, : len(tok_seq) - 1] = tok_seq[:-1] + y[i, : len(tok_seq) - 1] = tok_seq[1:] + # Adjust patch lengths to match x + x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1) + + assert x_patch_lengths.shape == (batch_size, seq_len) + + if enable_byte_ngrams: + raise NotImplementedError() + else: + ngram_ids = None + + batch = Batch( + x=x, + y=y, + patch_lengths=x_patch_lengths, + ngram_ids=ngram_ids, + mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks), + ) + assert ( + x_patch_lengths.sum() == x.size + batch_size + ), f"{x_patch_lengths.sum()} != {x.size + batch_size}" + assert ( + batch.mask is None or np.sum(x != pad_id) == batch.mask.sum() + ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}" + assert np.all( + x_patch_lengths[:, 0] == 1 + ), f"first patch should always be 1, {x_patch_lengths[:, 0]}" + # cuda_gb_allocated = (torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024) + # cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024 + # print(f"dataloader cuda_gb_allocated: {cuda_gb_allocated}, cuda_gb_reserved: {cuda_gb_reserved}") + truncate_batch( + batch, + max_length=max_length, + pad_id=pad_id, + pad_to_max_length=pad_to_max_length, + enable_byte_ngrams=enable_byte_ngrams, + ) + yield batch diff --git a/bytelatent/data/iterators/preprocess_iterator.py b/bytelatent/data/iterators/preprocess_iterator.py new file mode 100644 index 0000000..8eeba41 --- /dev/null +++ b/bytelatent/data/iterators/preprocess_iterator.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
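# --- Illustrative check (not part of this commit) ----------------------------
# Why PackingIterator above asserts x_patch_lengths.sum() == x.size + batch_size:
# each packed row covers T tokens, but x/y keep only T - 1 positions
# (inputs are tokens[:-1], targets are tokens[1:]), so the patch lengths sum to
# one extra token per row. The numbers below are made up.
import numpy as np

tokens = [[5, 6, 7, 8, 9], [9, 8, 7, 6, 5]]        # two rows of T = 5 token ids
patch_lengths = np.array([[1, 2, 2], [1, 3, 1]])   # per-row patches covering all 5 tokens
x = np.array([row[:-1] for row in tokens])         # shape (2, 4)
y = np.array([row[1:] for row in tokens])          # shape (2, 4)
assert patch_lengths.sum() == x.size + len(tokens)  # 10 == 8 + 2
# -----------------------------------------------------------------------------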
+from typing import Any, Generator + +import torch +from pydantic import BaseModel, ConfigDict + +from bytelatent.data.data_types import BltExample +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.arrow_iterator import ( + ArrowFileIterator, + ArrowFileIteratorState, +) +from bytelatent.data.iterators.looping_iterator import LoopingIteratorState +from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum +from bytelatent.tokenizers.blt_tokenizer import BltTokenizer +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +class PreprocessIteratorState(BaseModel, IteratorState): + model_config = ConfigDict(extra="forbid") + arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState + add_tokens: bool + add_patches: bool + tokenizer_args: TokenizerArgs + patcher_args: PatcherArgs + + def build(self): + arrow_iterator = self.arrow_file_iterator_state.build() + return PreprocessIterator( + arrow_iterator, + patcher_args=self.patcher_args, + tokenizer_args=self.tokenizer_args, + add_tokens=self.add_tokens, + add_patches=self.add_patches, + ) + + +class PreprocessIterator(StatefulIterator): + """ + Take BltExamples with fields filled in only from ArrowFileIterator, and fill in fields that require + preprocessing like tokenization and patching + """ + + def __init__( + self, + arrow_iterator: ArrowFileIterator, + *, + patcher_args: PatcherArgs, + tokenizer_args: TokenizerArgs, + add_tokens: bool = True, + add_patches: bool = True, + ): + self.arrow_iterator = arrow_iterator + self.tokenizer_args = tokenizer_args + self.patcher_args = patcher_args + self.add_tokens = add_tokens + self.add_patches = add_patches + self.tokenizer: BltTokenizer | None = None + self.patcher: Patcher | None = None + + def get_state(self) -> PreprocessIteratorState: + """ + The only state to maintain here is from arrow, there + isn't any internal state on this iterator. 
+ """ + return PreprocessIteratorState( + arrow_file_iterator_state=self.arrow_iterator.get_state(), + tokenizer_args=self.tokenizer_args, + patcher_args=self.patcher_args, + add_tokens=self.add_tokens, + add_patches=self.add_patches, + ) + + def create_iter(self) -> Generator[BltExample, Any, None]: + if self.tokenizer is None and self.add_tokens: + self.tokenizer = self.tokenizer_args.build() + if self.patcher is None and self.add_patches: + self.patcher = self.patcher_args.build() + + example_iter = self.arrow_iterator.create_iter() + for example in example_iter: + if self.add_tokens: + tokens = self.tokenizer.encode(example.text) + else: + tokens = example.tokens + if ( + self.patcher is not None + and self.patcher.patching_mode == PatchingModeEnum.entropy + ): + assert ( + example.entropies is not None + ), "For patching, entropies cannot be None" + entropies = torch.tensor(example.entropies).unsqueeze(0) + else: + entropies = None + if self.patcher is None: + patch_lengths = None + else: + patch_lengths = self.patcher.patch( + torch.tensor(tokens).unsqueeze(0), + include_next_token=False, + entropies=entropies, + )[0][0].tolist() + yield BltExample( + sample_id=example.sample_id, + text=example.text, + tokens=tokens, + mask=[True] * len(tokens), + patch_lengths=patch_lengths, + entropies=example.entropies, + ) diff --git a/bytelatent/data/iterators/sampling_iterator.py b/bytelatent/data/iterators/sampling_iterator.py new file mode 100644 index 0000000..6474bf6 --- /dev/null +++ b/bytelatent/data/iterators/sampling_iterator.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import Any + +import numpy as np +from pydantic import BaseModel, ConfigDict + +from bytelatent.data.iterators.abstract_iterator import StatefulIterator +from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState + + +class SamplingIteratorState(BaseModel): + model_config = ConfigDict(extra="forbid") + rng_state: dict[str, Any] + source_to_weight: dict[str, float] + source_to_iterator_state: dict[str, SequenceIteratorState] + + def build(self) -> "SamplingIterator": + return SamplingIterator( + rng_state=self.rng_state, + source_to_weight=self.source_to_weight, + source_to_iterator={ + source: state.build() + for source, state in self.source_to_iterator_state.items() + }, + ) + + +class SamplingIterator(StatefulIterator): + def __init__( + self, + *, + rng_state: dict[str, Any], + source_to_weight: dict[str, float], + source_to_iterator: dict[str, StatefulIterator], + ): + self.rng = np.random.default_rng() + self.rng.bit_generator.state = rng_state + self.source_to_weight = source_to_weight + self.source_to_iterator = source_to_iterator + + def get_state(self) -> SamplingIteratorState: + return SamplingIteratorState( + rng_state=self.rng.bit_generator.state, + source_to_weight=self.source_to_weight, + source_to_iterator_state={ + source: iterator.get_state() + for source, iterator in self.source_to_iterator.items() + }, + ) + + def create_iter(self): + n_sources = len(self.source_to_weight) + possible_sources = [] + weights = [] + for source, w in self.source_to_weight.items(): + possible_sources.append(source) + weights.append(w) + + source_to_python_iter = { + source: self.source_to_iterator[source].create_iter() + for source in possible_sources + } + while True: + norm_weights = np.array(weights) / np.array(weights).sum() + source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)] + yield next(source_to_python_iter[source_choice]) diff 
--git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py new file mode 100644 index 0000000..14e3747 --- /dev/null +++ b/bytelatent/data/iterators/sequence_iterator.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from logging import getLogger +from typing import Any + +import numpy as np +from pydantic import BaseModel, ConfigDict + +from bytelatent.data.data_types import BltSequence +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.preprocess_iterator import ( + PreprocessIterator, + PreprocessIteratorState, +) + +logger = getLogger() + + +class SequencePackingArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + output_seq_len: int + buffer_size: int + + +class SequenceIteratorState(BaseModel, IteratorState): + model_config = ConfigDict(extra="forbid") + sequence_packing_args: SequencePackingArgs + preprocess_iterator_state: PreprocessIteratorState + rng_state: dict[str, Any] + + def build(self): + preprocess_iterator = self.preprocess_iterator_state.build() + return SequenceIterator( + preprocess_iterator, + sequence_packing_args=self.sequence_packing_args, + rng_state=self.rng_state, + ) + + +class SequenceIterator(StatefulIterator): + def __init__( + self, + preprocess_iterator: PreprocessIterator, + *, + rng_state: dict[str, Any], + sequence_packing_args: SequencePackingArgs, + ): + self.preprocess_iterator = preprocess_iterator + self.sequence_packing_args = sequence_packing_args + self.output_seq_len = sequence_packing_args.output_seq_len + self.buffer_size = sequence_packing_args.buffer_size + self.rng = np.random.default_rng() + self.rng.bit_generator.state = rng_state + + def get_state(self): + # TODO: need to also perist the current shuffle buffer + return SequenceIteratorState( + sequence_packing_args=self.sequence_packing_args, + preprocess_iterator_state=self.preprocess_iterator.get_state(), + rng_state=self.rng.bit_generator.state, + ) + + def create_iter(self): + example_iter = self.preprocess_iterator.create_iter() + n_buffer_patches = self.buffer_size * self.output_seq_len + + patch_lengths: list[int] = [] + tokens: list[int] = [] + mask: list[bool] = [] + first = True + for example in example_iter: + assert example.tokens is not None + assert example.mask is not None + assert example.patch_lengths is not None + assert len(example.tokens) != 0 + assert len(example.mask) != 0 + assert len(example.tokens) == len(example.mask) + assert len(example.tokens) == sum(example.patch_lengths) + + tokens.extend(example.tokens) + mask.extend(example.mask) + patch_lengths.extend(example.patch_lengths) + + while len(patch_lengths) >= n_buffer_patches: + if first: + first = False + logger.info("First buffer complete") + + x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape( + self.buffer_size, self.output_seq_len + ) + seq_tokens = [] + seq_mask = [] + start_id = 0 + # We fix the number of patches and therefore global steps per batch + # so we have a variable number of tokens we need to account for + for num_tokens in x_patches.sum(axis=-1): + seq_tokens.append(tokens[start_id : start_id + num_tokens]) + seq_mask.append(mask[start_id : start_id + num_tokens]) + start_id += num_tokens + + assert start_id == x_patches.sum() + + # Remove what we just added from the buffer + patch_lengths = patch_lengths[n_buffer_patches:] + tokens = tokens[x_patches.sum() :] + mask = mask[x_patches.sum() :] + + seq_patch_lengths: 
list[list[int]] = x_patches.tolist() + assert len(seq_patch_lengths) == self.buffer_size + for idx in self.rng.permutation(len(seq_patch_lengths)): + assert len(seq_patch_lengths[idx]) == self.output_seq_len + assert ( + sum(seq_patch_lengths[idx]) + == len(seq_tokens[idx]) + == len(seq_mask[idx]) + ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}" + assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}" + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=seq_patch_lengths[idx], + ) diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py new file mode 100644 index 0000000..4266427 --- /dev/null +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import numpy as np +import pyarrow as pa + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + +from bytelatent.constants import BLT_DATA +from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState + +ENTROPY_MODEL = "transformer_100m" +ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow") +ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow") + + +def test_basic_arrow_file(): + dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow") + n_head = 1000 + head_df = dataset.head(n_head).to_pandas() + + initial_state = ArrowFileIteratorState( + file_path=None, + num_workers=1, + worker_id=0, + preprocess_dir=None, + entropy_model_name=ENTROPY_MODEL, + dataset_files=[ARROW_TEST_DATA_1], + row_num=0, + arrow_batch_size=100, + ) + arrow_file = initial_state.build() + start_state = arrow_file.get_state() + assert start_state.row_num == initial_state.row_num + + sample_id = None + for example in arrow_file.create_iter(): + sample_id = example.sample_id + assert head_df.iloc[0]["sample_id"] == sample_id + break + + assert arrow_file.get_state().row_num == 1 + arrow_file = initial_state.build() + for example in arrow_file.create_iter(): + assert example.sample_id == sample_id + assert head_df.iloc[0]["sample_id"] == sample_id + break + + # Test resume far enough in to be past the batch size of 100 + resumed_state = ArrowFileIteratorState( + file_path=None, + num_workers=1, + worker_id=0, + preprocess_dir=None, + entropy_model_name=ENTROPY_MODEL, + dataset_files=[ARROW_TEST_DATA_1], + row_num=251, + arrow_batch_size=100, + ) + arrow_file = resumed_state.build() + for example in arrow_file.create_iter(): + assert example.sample_id == head_df.iloc[251]["sample_id"] + assert arrow_file.get_state().row_num == 252 + break + + world_rank = 1 + world_size = 4 + # Test World Size and Rank + rank_state = ArrowFileIteratorState( + file_path=None, + num_workers=world_size, + worker_id=world_rank, + preprocess_dir=None, + entropy_model_name=ENTROPY_MODEL, + dataset_files=[ARROW_TEST_DATA_1], + row_num=0, + arrow_batch_size=100, + ) + arrow_file = rank_state.build() + expected_ids = [] + for i in range(n_head): + if i % world_size == world_rank: + expected_ids.append(head_df.iloc[i]["sample_id"]) + print(len(expected_ids)) + i = 0 + for example in arrow_file.create_iter(): + assert example.sample_id == expected_ids[i] + i += 1 + if i >= len(expected_ids): + break diff --git a/bytelatent/data/iterators/test_iters.py b/bytelatent/data/iterators/test_iters.py new file mode 100644 index 0000000..9bc9d59 --- /dev/null +++ 
b/bytelatent/data/iterators/test_iters.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import pandas as pd +from pydantic import BaseModel + +from bytelatent.constants import BLT_DATA +from bytelatent.data.data_types import BltExample +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator +from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +class BltTestIteratorState(BaseModel, IteratorState): + position: int + total: int + + def build(self): + blt_iter = BltTestIteratorState(total=self.total) + blt_iter.position = self.position + return blt_iter + + +class BltTestIterator(StatefulIterator): + def __init__(self, total: int): + self.position = 0 + self.total = total + + def get_state(self): + return BltTestIteratorState(position=self.position, total=self.total) + + def create_iter(self): + for i in range(self.total): + self.position += 1 + yield BltExample( + sample_id=f"test_{i}", + text=f"This is some test {i} text.", + tokens=None, + mask=None, + entropies=None, + patch_lengths=None, + ) + + +class BltTestWithEntropiesIteratorState(BaseModel, IteratorState): + position: int + total: int + + def build(self): + blt_iter = BltTestWithEntropiesIteratorState(total=self.total) + blt_iter.position = self.position + return blt_iter + + +class BltTestWithEntropiesIterator(StatefulIterator): + def __init__(self, total: int): + self.position = 0 + self.total = total + + def get_state(self): + return BltTestIteratorState(position=self.position, total=self.total) + + def create_iter(self): + text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin." 
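# --- Illustrative sketch (not part of this commit) ---------------------------
# SamplingIterator and SequenceIterator above both checkpoint their numpy rng
# through rng.bit_generator.state, which is a plain dict and therefore easy to
# store in a pydantic state object. A minimal round trip; all names hypothetical.
import numpy as np

rng = np.random.default_rng(0)
weights = np.array([0.7, 0.3])
sources = ["wikipedia", "stackexchange"]
first_pick = sources[rng.choice(len(sources), p=weights)]

saved_state = rng.bit_generator.state        # what get_state() stores

resumed = np.random.default_rng()
resumed.bit_generator.state = saved_state    # what build() restores
second_pick = sources[resumed.choice(len(sources), p=weights)]
# `resumed` now continues the exact random stream `rng` would have produced.
# -----------------------------------------------------------------------------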
+ df = pd.read_json("fixtures/tokens_with_entropies.json") + tokens = df["token_ids"].tolist() + entropies = df["entropies"].tolist() + # BOS and EOS + assert len(tokens) == len(text) + 2 + for i in range(self.total): + self.position += 1 + yield BltExample( + sample_id=f"test_{i}", + text=text, + tokens=tokens, + mask=[True] * len(tokens), + entropies=entropies, + patch_lengths=None, + ) + + +def test_preprocess_iter(): + total = 3 + tokenizer_args = TokenizerArgs( + name="blt", + init_kwargs={ + "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" + }, + ) + for mode in [ + PatchingModeEnum.bpe, + PatchingModeEnum.space, + ]: + data_it = BltTestIterator(total) + patcher_args = PatcherArgs(patching_mode=mode) + example_it = PreprocessIterator( + data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args + ) + count = 0 + for example in example_it.create_iter(): + assert isinstance(example.tokens, list) + assert isinstance(example.tokens[0], int) + # BOS and EOS + assert len(example.tokens) == len(example.text) + 2 + assert example.mask is not None + assert len(example.tokens) == len(example.mask) + count += 1 + + assert count == total + + +def test_non_entropy_patch_iter(): + total = 3 + tokenizer_args = TokenizerArgs( + name="blt", + init_kwargs={ + "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" + }, + ) + for mode in [ + PatchingModeEnum.bpe, + PatchingModeEnum.space, + ]: + patcher_args = PatcherArgs(patching_mode=mode) + data_it = BltTestIterator(total) + example_it = PreprocessIterator( + data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args + ) + + count = 0 + for example in example_it.create_iter(): + assert isinstance(example.patch_lengths, list) + assert isinstance(example.patch_lengths[0], int) + assert len(example.tokens) == sum(example.patch_lengths) + count += 1 + + assert count == total + + +def test_entropy_patch_iter(): + total = 2 + patcher_args = PatcherArgs( + patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627 + ) + tokenizer_args = TokenizerArgs( + name="blt", + init_kwargs={ + "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" + }, + ) + data_it = BltTestWithEntropiesIterator(total) + example_it = PreprocessIterator( + data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args + ) + + count = 0 + for example in example_it.create_iter(): + assert isinstance(example.patch_lengths, list) + assert isinstance(example.patch_lengths[0], int) + assert len(example.tokens) == sum(example.patch_lengths) + count += 1 + + assert count == total diff --git a/bytelatent/data/ngram_processor.py b/bytelatent/data/ngram_processor.py new file mode 100644 index 0000000..2498183 --- /dev/null +++ b/bytelatent/data/ngram_processor.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import pickle +from pathlib import Path + +import numpy as np + +from bytelatent import ByteLatentError + +LOOKUP_OFFSET = 4 + + +def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1): + """ + Wrapper function for applying the lookup table to each n-gram. + + :param ngram: Array of numbers representing an n-gram. + :param lookup_table: Dictionary where keys are tuples (n-grams) and values are the desired outputs. + :param lookup_offset: Offset to add to the lookup result. + :return: The value associated with the n-gram tuple in the dictionary, or None if not found. 
+ """ + + def apply_lookup_table(ngram): + """ + Function to apply to each n-gram: converts it to a tuple and looks it up in a dictionary. + + :param ngram: Array of numbers representing an n-gram. + :return: The value associated with the n-gram tuple in the dictionary, or None if not found. + """ + # Convert the n-gram to a tuple + ngram_tuple = tuple(ngram) + + if ngram_tuple not in ngram_to_idx: + return 0 + else: + return ngram_to_idx[ngram_tuple] + lookup_offset + + return apply_lookup_table + + +def get_byte_ngrams_ids( + byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0 +): + """ + Generate n-grams from a 2D numpy array. + + :param n: The length of each n-gram. + :param pad_value: The value used for padding of the byte values to maintain the same dimensions for the n-grams. + :return: A 2D numpy array where each element is the ID of an n-gram offset by LOOKUP_OFFSET. + """ + num_rows, num_cols = byte_array.shape + + # Create an array to hold the padded version of the original array + padded_array = np.pad( + byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value + ) + + # Use stride tricks to avoid explicit looping + strided = np.lib.stride_tricks.as_strided + shape = (num_rows, num_cols, n) + strides = padded_array.strides[:2] + (padded_array.strides[1],) + ngrams = strided(padded_array, shape=shape, strides=strides) + + ngram_ids = np.apply_along_axis( + apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET), 2, ngrams + ) + assert ngram_ids.shape == byte_array.shape + return ngram_ids + + +def reload_tables( + ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET +) -> tuple[dict[int, list], dict[tuple, int], dict[int, int]]: + """ + Reload lookup tables from a directory. Reload only the ngrams in the dictionary and per ngram, + only load up to the max specified size. Return the actual number of ngrams taken per ngram size. 
+ """ + idx_to_ngram_tables = {} + ngram_to_idx_tables = {} + vocab_sizes = {} + for ngram, size in ngram_to_size.items(): + with open(Path(ngram_table_dir) / f"ngram-{ngram}.pickle", "rb") as f: + # These are already sorted by count + # Value: tuple of: count, ngram, dataset + ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[ + "counts" + ] + table = [ngram for ngram, _ in ngram_data][:size] + if len(table) != size: + raise ValueError( + f"Ngram table for {ngram}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}" + ) + ngram_to_idx = {ngram: idx for idx, ngram in enumerate(table)} + actual_size = len(table) + idx_to_ngram_tables[ngram] = table + ngram_to_idx_tables[ngram] = ngram_to_idx + vocab_sizes[ngram] = actual_size + offset + return ngram_to_idx_tables, ngram_to_idx_tables, vocab_sizes + + +def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]: + if ngram_to_size_str is None: + return None + ngram_to_size = {} + for entry in ngram_to_size_str.split(","): + ngram, size = entry.split(":") + ngram = int(ngram) + size = int(size) + ngram_to_size[ngram] = size + return ngram_to_size + + +class NgramProcessor: + def __init__( + self, + ngram_table_dir: str | None = None, + ngram_to_size: dict[int, int] | None = None, + ): + if ngram_table_dir is None or ngram_to_size is None: + raise ByteLatentError( + "ngram_table_dir and ngram_to_size cannot be none if enable_byte_ngrams is True" + ) + ( + self.ngram_to_idx_tables, + self.idx_to_ngram_tables, + self.ngram_vocab_sizes, + ) = reload_tables(ngram_table_dir, ngram_to_size) + # Lowest to highest ngram + self.ngram_sizes = sorted(list(self.ngram_to_idx_tables.keys())) + # Although the model might not use all the ngrams, we need the tokenizer + # to produce ngram_ids such that index zero is the 2-gram, later on in + # src.model.megabyte.Megabyte.forward + assert self.ngram_sizes[0] == 2 + + def encode_single_ngram_table(self, data: np.ndarray, n: int): + """ + Return the n-grams of the input data for a given n + numpy array with ids of shape data.shape + """ + return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0) + + def encode_token_ngrams(self, data: np.ndarray): + """ + Return the n-grams of the input data. + output shape: [ids with data.shape for n in self.ngram_sizes] + """ + return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes] diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py new file mode 100644 index 0000000..ede8b06 --- /dev/null +++ b/bytelatent/data/patcher.py @@ -0,0 +1,609 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
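# --- Illustrative example (not part of this commit) --------------------------
# What get_byte_ngrams_ids in ngram_processor.py above computes for a tiny
# input: the byte array is left-padded with n - 1 pad values, every length-n
# window is looked up in the table, unknown n-grams map to 0 and known ones map
# to their index plus LOOKUP_OFFSET (4). The table below is made up.
import numpy as np

byte_array = np.array([[10, 11, 12, 13]])
ngram_to_idx = {(10, 11): 0, (11, 12): 1}    # hypothetical 2-gram table
n = 2
padded = np.pad(byte_array, ((0, 0), (n - 1, 0)), constant_values=0)
ids = np.array([
    [ngram_to_idx.get(tuple(padded[r, c:c + n]), -4) + 4 for c in range(byte_array.shape[1])]
    for r in range(byte_array.shape[0])
])
# windows: (0, 10) unknown -> 0, (10, 11) -> 4, (11, 12) -> 5, (12, 13) unknown -> 0
assert ids.tolist() == [[0, 4, 5, 0]]        # same shape as byte_array
# -----------------------------------------------------------------------------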
+import math +import time +from collections import defaultdict +from enum import Enum + +import torch +from pydantic import BaseModel +from torch.nn import functional as F + +from bytelatent.distributed import get_local_rank +from bytelatent.entropy_model import load_entropy_model + +# from src.slurm import get_local_rank +from bytelatent.tokenizers.blt_tokenizer import BPE_ID, OFFSET +from bytelatent.tokenizers.constants import BPE_ID, OFFSET + + +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + + +class PatcherArgs(BaseModel): + patching_mode: PatchingModeEnum = PatchingModeEnum.entropy + patching_device: str = "cuda" + entropy_model_checkpoint_dir: str | None = None + realtime_patching: bool = False + threshold: float = 1.335442066192627 + threshold_add: float | None = None + max_patch_length: int | None = None + patch_size: float = 4.5 + patching_batch_size: int = 1 + data_loader_patching: bool = False + device: str = "cuda" + monotonicity: bool = False + log_time: bool = False + + def build(self) -> "Patcher": + return Patcher(self) + + +def entropy(scores): + """ + scores: [bs, seq_len, vocab] + returns [bs, seq_len] + + Computes the entropy for each token in the batch. + Note: uses natural log. + """ + log_probs = F.log_softmax(scores, dim=-1) + probs = torch.exp(log_probs) + p_log_p = log_probs * probs + entropy = -p_log_p.sum(dim=-1) + return entropy + + +def calculate_entropies( + tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None +): + """ + tokens: 2D tensor of shape [batch_size, seq_len] + Return 2D tensor of shape [batch_size, seq_len] with entropies for each token. + + Splits the tokens into chunks of size max_length and calculates entropies for each chunk. + Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument. 
+ """ + with torch.no_grad(): + entropies = [] + max_length = getattr(entropy_model, "max_length", 8192) + batch_numel = max_length * patching_batch_size + splits = torch.split(tokens.flatten(), batch_numel) + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros( + pad_size, dtype=split.dtype, device=split.device, requires_grad=False + ) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + assert torch.all(split >= 0) and torch.all(split < 260) + pred, _ = entropy_model(split) + pred = pred.reshape(-1, pred.shape[-1])[ + : split.numel() - pad_size, : + ] # [batch_size * seq_len, vocab] + pred_entropies = entropy(pred) + entropies.append(pred_entropies) + + entropies = torch.cat(entropies, dim=0) + entropies = entropies.reshape(tokens.shape) + return entropies + + +def patch_start_mask_from_entropy_with_monotonicity(entropies, t): + """ + entropies: [bs, seq_len] torch tensor of entropies + t: threshold + returns [bs, seq_len] mask where True indicates the start of a patch + """ + bs, seq_len = entropies.shape + mask = torch.zeros_like(entropies, dtype=torch.bool) + mask[:, 0] = True + + # Calculate differences between consecutive elements along the sequence length + differences = entropies[:, 1:] - entropies[:, :-1] + + # Calculate conditions for all elements except the first one in each sequence + condition = differences > t + + # Update the mask based on the condition + mask[:, 1:] = condition + + return mask + + +def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0): + """ + entropies: [bs, seq_len] torch tensor of entropies + t: threshold + returns [bs, seq_len] mask where True indicates the start of a patch + """ + bs, seq_len = entropies.shape + mask = torch.zeros_like(entropies, dtype=torch.bool) + mask[:, 0] = True + + # Calculate differences between consecutive elements along the sequence length + differences = entropies[:, 1:] - entropies[:, :-1] + + # Calculate conditions for all elements except the first one in each sequence + condition = (differences > t_add) & (entropies[:, 1:] > t) & (~mask[:, :-1]) + + # Update the mask based on the condition + mask[:, 1:] = condition + + return mask + + +def patch_start_ids_from_patch_start_mask(patch_start_mask): + bs, trunc_seq_len = patch_start_mask.shape + max_patches = patch_start_mask.sum(dim=1).max() + if max_patches == 0: + patch_start_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + else: + patch_ids = ( + torch.arange(trunc_seq_len, device=patch_start_mask.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + extra_patch_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat( + (patch_start_mask, ~patch_start_mask), dim=1 + ) + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape( + bs, trunc_seq_len + )[:, :max_patches] + return patch_start_ids + + +def check_non_zero_after_zero(tensor): + zero_mask = tensor == 0 + shifted_mask = torch.cat( + [ + torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), + zero_mask[:, :-1], + ], + dim=1, + ) + non_zero_after_zero = (tensor != 0) & shifted_mask + return non_zero_after_zero.any() + + +def patch_lengths_from_start_ids(patch_start_ids, seq_len): + """ + Calculate patch lengths from 
start ids. + start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then + the rest are filled to the seq len. + seq_len: ex: 7 length of the sequence + + returns the patch lengths: + [1, 6] for the above example. + """ + last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) + patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) + patch_lengths = patch_end_ids - patch_start_ids + 1 + assert torch.all(patch_lengths >= 0), f"{patch_lengths}" + assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" + return patch_lengths + + +def find_space_patch_start_ids(tokens): + bs, seq_len = tokens.shape + tokens_no_offset = tokens - OFFSET + patch_end_mask = ( + (tokens_no_offset < ord("0")) + | ((ord("9") < tokens_no_offset) & (tokens_no_offset < ord("A"))) + | ((ord("Z") < tokens_no_offset) & (tokens_no_offset < ord("a"))) + | ((ord("z") < tokens_no_offset) & (tokens_no_offset < 0b1000_0000)) + | (0b1100_0000 <= tokens_no_offset) + ) + patch_end_mask[:, 1:] &= patch_end_mask[:, :-1].bitwise_not() + patch_end_mask |= tokens < OFFSET + + patch_start_mask = torch.cat( + [ + torch.tensor([1, 1], device=tokens.device, dtype=torch.bool) + .unsqueeze(0) + .repeat(bs, 1), + patch_end_mask[:, 1:], + ], + dim=1, + ) + max_patches = patch_start_mask.sum(dim=1).max() + + patch_ids = ( + torch.arange(seq_len + 1, device=tokens.device).unsqueeze(0).repeat(bs, 1) + ) + extra_patch_ids = torch.full( + (bs, seq_len + 1), seq_len + 1, dtype=torch.long, device=tokens.device + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) + + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, -1)[ + :, :max_patches + ] + return patch_start_ids + + +def to_device(entropy_model, device=None): + if device == "cuda": + rank = get_local_rank() + device = f"cuda:{rank}" + entropy_model = entropy_model.to(device) + return entropy_model, device + + +def model_pred_to_bpe_patching_pred(pred): + _, indices = torch.max(pred, dim=1) + return indices == BPE_ID + + +def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None): + assert tokens.device == torch.device( + "cpu" + ), f"{tokens.device} != cpu expects tokens to be on cpu" + with torch.no_grad(): + bpe_patcher_device, device = to_device( + bpe_patcher, device + ) # Get entropy model to right rank device. 
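# --- Illustrative example (not part of this commit) --------------------------
# entropy() near the top of this file uses the natural log, so perfectly
# uniform logits over a vocabulary of 260 byte tokens give ln(260) for every
# position. Assumes the module-level entropy() defined above.
import math
import torch

uniform_scores = torch.zeros(1, 3, 260)              # [bs, seq_len, vocab]
per_token_entropy = entropy(uniform_scores)          # shape [1, 3]
assert torch.allclose(per_token_entropy, torch.full((1, 3), math.log(260)))
# -----------------------------------------------------------------------------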
+ bpe_patching_mask = [] + max_length = getattr(bpe_patcher, "max_length", 8192) + batch_numel = max_length * patching_batch_size + splits = torch.split(tokens.flatten(), batch_numel) + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros( + pad_size, dtype=split.dtype, device=split.device, requires_grad=False + ) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length).to(device) + assert torch.all(split >= 0) and torch.all(split < 260) + pred = bpe_patcher_device(split) + pred_cpu = pred[0].cpu() + pred_cpu = pred_cpu.reshape(-1, pred_cpu.shape[-1])[ + : split.numel() - pad_size, : + ] # [batch_size * seq_len, vocab] + bpe_patching_pred = model_pred_to_bpe_patching_pred(pred_cpu) + bpe_patching_mask.append(bpe_patching_pred) + bpe_patching_mask = torch.cat(bpe_patching_mask, dim=0) + bpe_patching_mask = bpe_patching_mask.reshape(tokens.shape) + return bpe_patching_mask + + +def find_bpe_patcher_patch_start_ids( + tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True +): + bs, seq_len = tokens.shape + + first_ids = ( + torch.tensor([0, 1], dtype=torch.long, device=tokens.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + preds_truncation_len = first_ids.shape[1] + token_input = tokens[:, 1:] if include_next_token else tokens[:, 1:-1] + if token_input.shape[1] >= 1: + patch_start_mask = apply_bpe_patcher( + token_input, bpe_patcher, patching_batch_size, device + ) + assert ( + patch_start_mask.shape[1] + == tokens.shape[1] + include_next_token - preds_truncation_len + ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}" + patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask) + patch_start_ids = torch.cat( + (first_ids, patch_start_ids + preds_truncation_len), dim=1 + ) + else: + patch_start_ids = first_ids + return patch_start_ids + + +def find_entropy_patch_start_ids( + entropies, + patch_size=None, + threshold=None, + threshold_add=None, + monotonicity=False, + include_next_token=True, +): + """ + Use entropies to find the start ids of each patch. + Use patch_size or threshold to figure out the total number of patches to allocate. + + When threshold is not None the number of patches is not constant between + different sequences, but patches can be identified incrementally rather than + decided globally using the entire sequence. + """ + bs, seq_len = entropies.shape[:2] + + first_ids = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + preds_truncation_len = first_ids.shape[ + 1 + ] # remove the first preds because they will be start of patches. 
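# --- Illustrative example (not part of this commit) --------------------------
# How a threshold over (already shifted) entropies becomes patch lengths, using
# patch_start_ids_from_patch_start_mask and patch_lengths_from_start_ids defined
# above. This mirrors the threshold branch of this function; numbers are made up.
import torch

shifted_entropies = torch.tensor([[2.0, 0.3, 1.9, 0.2, 0.4]])  # entropies[:, 1:]
mask = shifted_entropies > 1.5                 # [[True, False, True, False, False]]
start_ids = patch_start_ids_from_patch_start_mask(mask)        # [[0, 2]]
first_ids = torch.tensor([[0, 1]])
start_ids = torch.cat((first_ids, start_ids + 2), dim=1)       # [[0, 1, 2, 4]]
lengths = patch_lengths_from_start_ids(start_ids, seq_len=7)
assert lengths.tolist() == [[1, 1, 2, 3]]      # sums to seq_len_next_tok = 7
# -----------------------------------------------------------------------------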
+ entropies = entropies[:, 1:] + if threshold is None: + num_patches = seq_len // patch_size + patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices + patch_start_ids = patch_start_ids.sort(dim=1).values + else: + # Assumes that there is at least one token going over the threshold + if monotonicity: + patch_start_mask = patch_start_mask_from_entropy_with_monotonicity( + entropies, threshold + ) + elif threshold_add is not None and threshold is not None: + patch_start_mask = patch_start_mask_global_and_monotonicity( + entropies, threshold, threshold_add + ) + else: + patch_start_mask = entropies > threshold + if not include_next_token: + patch_start_mask = patch_start_mask[:, :-1] + # patch_start_mask[1:] |= tokens[:-1] < OFFSET + patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask) + + patch_start_ids = torch.cat( + (first_ids, patch_start_ids + preds_truncation_len), dim=1 + ) + return patch_start_ids + + +def rightpad(seq, pad_id, max_len): + return seq + [pad_id] * (max_len - len(seq)) + + +def find_bpe_delim_patch_start_ids(tokens, delim): + ids = (tokens[:, :-1] == delim).nonzero(as_tuple=False) + out = [[0, 1] for _ in range(tokens.shape[0])] + for x, y in ids: + # start is at delim + 1, delim should be the last element in the patch. + out[x.item()].append(y.item() + 1) + max_len = max([len(elt) for elt in out]) + out = [rightpad(elt, tokens.shape[1], max_len) for elt in out] + patch_start_ids = torch.tensor(out, dtype=tokens.dtype, device=tokens.device) + return patch_start_ids + + +def find_lookup_table_start_mask( + tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True +): + window_size = lookup_table.ndim + # Unfold the tensor to get sliding windows + unfolded = tokens.unfold(1, window_size, 1) + # Gather indices for each dimension + indices = [unfolded[..., i] for i in range(window_size)] + # Access the lookup table using the gathered indices + result = lookup_table[indices] + return result + + +def find_lookup_table_patch_start_ids( + tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True +): + bs, seq_len = tokens.shape + + first_ids = ( + torch.tensor([0, 1], dtype=torch.long, device=tokens.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + preds_truncation_len = first_ids.shape[1] + window_size = lookup_table.ndim + assert window_size == 2, f"{window_size} != 2" + # output dimensions: token_input shape - window_size + 1 --> we want first ids + this = tokens shape + 1 if next token otherwise just token shape + token_input = ( + tokens if include_next_token else tokens[:, : -preds_truncation_len + 1] + ) + if token_input.shape[1] >= window_size: + patch_start_mask = find_lookup_table_start_mask( + token_input, lookup_table, include_next_token + ) + assert ( + patch_start_mask.shape[1] + == tokens.shape[1] + include_next_token - preds_truncation_len + ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}" + patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask) + patch_start_ids = torch.cat( + (first_ids, patch_start_ids + preds_truncation_len), dim=1 + ) + else: + patch_start_ids = first_ids + return patch_start_ids + + +def split_large_numbers(lst, m): + new_lst = [] + for i in lst: + if i > m: + while i > m: + new_lst.append(m) + i -= m + new_lst.append(i) + else: + new_lst.append(i) + assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" + return new_lst + + +class Patcher: + def __init__(self, patcher_args: PatcherArgs): + 
self.patcher_args = patcher_args + self.patching_mode = patcher_args.patching_mode + self.realtime_patching = patcher_args.realtime_patching + if self.realtime_patching: + assert ( + patcher_args.entropy_model_checkpoint_dir is not None + ), "Cannot require realtime patching without an entropy model checkpoint" + entropy_model = load_entropy_model( + patcher_args.entropy_model_checkpoint_dir + ) + entropy_model, _ = to_device(entropy_model, patcher_args.patching_device) + self.entropy_model = entropy_model + else: + self.entropy_model = None + self.threshold = patcher_args.threshold + self.threshold_add = patcher_args.threshold_add + self.max_patch_length = patcher_args.max_patch_length + self.patch_size = patcher_args.patch_size + self.patching_batch_size = patcher_args.patching_batch_size + self.data_loader_patching = patcher_args.data_loader_patching + self.device = patcher_args.device + self.monotonicity = patcher_args.monotonicity + self.log_time = patcher_args.log_time + if self.log_time: + self.log = defaultdict(float) + + def patch( + self, + tokens: torch.Tensor, + include_next_token: bool = False, + preds: torch.Tensor | None = None, + entropies: torch.Tensor | None = None, + threshold: float = None, + ) -> torch.Tensor: + """ + tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched + Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.) + -> output tensor: [batch_size, max_num_patches] + each tensor is processed independently and gets right padded with zeros. + + Patching with the following modes: + 1. patching_mode = None: static patch size + 2. patching_mode = "entropy": + calculate entropy of each token, allocate patches so that the total + number of patches is the same as static patching but choose to begin + patches on tokens where the model is most uncertain (highest entropy). + + When threshold is provided, it uses the threshold to decide when to + start a new patch. + 3. patching_mode = "space": + use space like tokens to define the patches. + 4. patching_mode = "bpe": + use bpe delim tokens to define the patches. + + To correctly patch the last token, it may be necessary to include the next token in the patch + lengths calculations. This is controlled by the include_next_token argument. 
+ """ + bs, seq_len = tokens.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + scores = None + # STATIC + if self.patching_mode is None: + patch_lengths = torch.zeros( + (bs, math.ceil(seq_len_next_tok / self.patch_size)), + dtype=tokens.dtype, + device=tokens.device, + ).fill_(self.patch_size) + if seq_len_next_tok % self.patch_size != 0: + patch_lengths[:, -1] = seq_len_next_tok % self.patch_size + # ENTROPY + elif self.patching_mode == PatchingModeEnum.entropy: + if self.log_time: + s = time.time() + if entropies is not None: + scores = torch.tensor(entropies, dtype=torch.float32) + elif preds is not None: + scores = entropy(preds) + else: + start_entropies = time.time() + scores = calculate_entropies( + tokens, + self.entropy_model, + self.patching_batch_size, + self.device, + ) + if self.log_time: + self.log["calculate_entropies"] += time.time() - s + s = time.time() + patch_start_ids = find_entropy_patch_start_ids( + scores, + self.patch_size, + include_next_token=include_next_token, + threshold=threshold if threshold is not None else self.threshold, + threshold_add=self.threshold_add, + monotonicity=self.monotonicity, + ) + if self.log_time: + self.log["find_entropy_patch_start_ids"] += time.time() - s + s = time.time() + patch_lengths = patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + if self.log_time: + self.log["patch_lengths_from_start_ids"] += time.time() - s + s = time.time() + # BPE + elif self.patching_mode == PatchingModeEnum.bpe: + patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID) + patch_lengths = patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + elif self.patching_mode == PatchingModeEnum.bpe_patcher: + patch_start_ids = find_bpe_patcher_patch_start_ids( + tokens, + self.entropy_model, + self.patching_batch_size, + self.device, + include_next_token, + ) + patch_lengths = patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + # SPACE + elif self.patching_mode == PatchingModeEnum.space: + patch_start_ids = find_space_patch_start_ids(tokens) + patch_lengths = patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + else: + raise NotImplementedError(f"self.patching_mode {self.patching_mode}") + + # Apply any processing to patch lengths + if self.max_patch_length is not None: + # TODO: avoid going back to a list here. 
+ patch_lengths = [ + split_large_numbers(pl, self.max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=tokens.dtype, device=tokens.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] + assert ( + torch.sum(patch_lengths) + == tokens.numel() + include_next_token * tokens.shape[0] + ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}" + if self.log_time: + self.log["postprocessing_patch_lengths"] += time.time() - s + self.log["tokens"] += patch_lengths.sum().item() + return patch_lengths, scores diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py new file mode 100644 index 0000000..fadce45 --- /dev/null +++ b/bytelatent/distributed.py @@ -0,0 +1,478 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import atexit +import contextlib +import logging +import multiprocessing as mp +import os +import random +import shutil +import signal +import socket +import subprocess +import sys +import tempfile +from dataclasses import asdict, dataclass +from functools import lru_cache, partial, reduce +from itertools import chain +from typing import List, Optional, Tuple, Union + +import torch + +# for no recompute ops +import xformers.ops +from pydantic import BaseModel, ConfigDict +from torch import distributed as dist +from torch.distributed import ReduceOp +from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard +from torch.distributed._tensor import DTensor +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, +) +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import ( + CheckpointPolicy, + create_selective_checkpoint_contexts, +) + +from bytelatent.float8 import convert_linears_to_fp8 + +logger = logging.getLogger() + +# for selective AC +default_no_recompute_ops = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.c10d_functional.reduce_scatter_tensor.default, + torch.ops.xformers_flash.flash_fwd.default, + torch.ops.xformers.efficient_attention_forward_cutlass.default, +} + + +class DistributedArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dp_shard: int = ( + 1 # In how many shard to split the model weight. Typically number gpu in a node. + ) + dp_replicate: int = ( + 1 # How many times to replicate the model weight. Typically number of nodes. + ) + tp_size: int = 1 + selective_activation_checkpointing: bool = False + compile: bool = False + fsdp_type: str = "no_shard" + model_dtype: str = "bf16" + float8_recipe: str | None = None + float8_filter: str = r"layers\.[0-9]+\." 
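# --- Illustrative usage sketch (not part of this commit) ---------------------
# Building the Patcher from patcher.py above in "space" mode, which needs no
# entropy model, and checking the invariant its patch() method asserts: the
# patch lengths cover every token plus one extra per row when
# include_next_token=True.
import torch
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum

patcher = PatcherArgs(
    patching_mode=PatchingModeEnum.space,
    patching_device="cpu",
    device="cpu",
).build()
tokens = torch.randint(4, 260, (2, 16))      # hypothetical token ids
patch_lengths, _ = patcher.patch(tokens, include_next_token=True)
assert patch_lengths.sum() == tokens.numel() + tokens.shape[0]
# -----------------------------------------------------------------------------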
+ + matmul_allow_tf32: bool = False + allow_bf16_reduced_precision_reduction: bool = True + detect_anomaly: bool = False + + compile_cache_size_limit: int = 8 + + spawn_method: str = "forkserver" + + +class EnvironmentArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + # Use GNU openMP (GOMP) instead of Intel OpenMP [Intel Math Kernel Library (MKL)] + MKL_SERVICE_FORCE_INTEL: str = "GNU" + OMP_NUM_THREADS: str = "1" + MKL_NUM_THREADS: str = "1" + # faster intra-node collectives, seems to be a cluster specific flag + ENABLE_INTRA_NODE_COMM: str = "1" + # avoids OOMs with long context + TORCH_NCCL_AVOID_RECORD_STREAMS: str = "1" + # increasing NCCL timeout time before having some NCCL error 22 should give a 16s timeout + NCCL_IB_TIMEOUT: str = "22" + NCCL_DEBUG: str = "INFO" + TORCH_NCCL_ASYNC_ERROR_HANDLING: str = "1" + + +def get_device_mesh(distributed_args: DistributedArgs): + tp_size = distributed_args.tp_size + dp_replicate = distributed_args.dp_replicate + dp_shard = distributed_args.dp_shard + + assert ( + dp_replicate * dp_shard * tp_size == get_world_size() + ), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})" + + dims = [] + names = [] + if dp_replicate >= 1: + dims.append(dp_replicate) + names.append("dp_replicate") + if dp_shard > 1 or distributed_args.fsdp_type == "no_shard": + dims.append(dp_shard) + names.append("dp_shard") + if tp_size > 1: + dims.append(tp_size) + names.append("tp") + dims = tuple(dims) + names = tuple(names) + + return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names) + + +def dist_max(x: Union[int, float], mesh: DeviceMesh = None): + tensor = torch.tensor(x).cuda() + dist.all_reduce(tensor, op=ReduceOp.MAX, group=mesh.get_group() if mesh else None) + return tensor + + +def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): + tensor = torch.tensor(x).cuda() + dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) + return tensor + + +def dist_mean_dict(x): + r = dict() + for k in x: + r[k] = dist_mean(x[k]) + r[k] = r[k].item() if (r[k].dim() == 0) else r[k].tolist() + return r + + +@lru_cache() +def get_is_torch_run() -> bool: + return os.environ.get("LOCAL_RANK") is not None + + +@lru_cache() +def get_is_slurm_job() -> bool: + return "SLURM_JOB_ID" in os.environ and not get_is_torch_run() + + +@lru_cache() +def get_global_rank() -> int: + if get_is_torch_run(): + return int(os.environ["RANK"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_PROCID"]) + else: + return 0 + + +@lru_cache() +def get_local_rank() -> int: + if get_is_torch_run(): + return int(os.environ["LOCAL_RANK"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_LOCALID"]) + else: + return 0 + + +@lru_cache() +def get_world_size() -> int: + if get_is_torch_run(): + return int(os.environ["WORLD_SIZE"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_NTASKS"]) + else: + return 1 + + +@lru_cache() +def get_is_master() -> bool: + return get_global_rank() == 0 + + +@lru_cache() +def get_master_port(job_id: int) -> int: + if get_is_torch_run(): + return int(os.environ["MASTER_PORT"]) + else: + MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000) + rng = random.Random(job_id) + return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) + + +@lru_cache() +def get_master_addr() -> str: + if get_is_torch_run(): + return os.environ["MASTER_ADDR"] + elif get_is_slurm_job(): + hostnames = subprocess.check_output( + ["scontrol", "show", "hostnames", 
os.environ["SLURM_JOB_NODELIST"]] + ) + return hostnames.split()[0].decode("utf-8") + else: + return "127.0.0.1" + + +def setup_env(env_args: EnvironmentArgs): + env_vars = env_args.model_dump() + + # When using Triton, it attempts to locate prebuilt kernels in a cache + # located at ~/.triton/cache, but when that's backed by NFS this can fail + # with a "OSError: [Errno 116] Stale file handle" error. If we were to set + # it to a local directory it would belong to the first user who created it + # and it would fail for the job of any other successive user assigned to + # that machine. To avoid all this mess we use a temporary per-process cache. + triton_cache_dir = tempfile.mkdtemp() + atexit.register(shutil.rmtree, triton_cache_dir, ignore_errors=True) + env_vars["TRITON_CACHE_DIR"] = triton_cache_dir + + # We change the tmp dir to /scratch in case it's slurm job + # This avoids filling up the host's usually limited tmpfs + # A full tmpfs leads to very slow creation of processes and weird bugs + if get_is_slurm_job(): + new_tmp = f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}" + if os.path.exists(new_tmp): + env_vars["TMP_DIR"] = new_tmp + + for name, value in env_vars.items(): + if os.environ.get(name) != str(value): + os.environ[name] = str(value) + logger.warning(f"WARNING: Setting {name} to {value}") + + +def setup_torch_distributed(dist_args): + """ + Handle single and multi-GPU / multi-node / SLURM jobs. + Initialize the following variables: + - global_rank + - world_size + """ + mp.set_start_method(dist_args.spawn_method) + with mp.Manager(): + pass + + local_rank = get_local_rank() + + os.environ["RANK"] = str(get_global_rank()) + os.environ["WORLD_SIZE"] = str(get_world_size()) + os.environ["MASTER_ADDR"] = get_master_addr() + os.environ["MASTER_PORT"] = str( + get_master_port(job_id=int(os.environ.get("SLURM_JOB_ID", -1))) + ) + + if get_is_torch_run(): + logger.info(f"Run launched with torchrun, local rank: {local_rank}") + elif get_is_slurm_job(): + logger.info(f"Run launched with slurm, local rank: {local_rank}") + else: + logger.info("Single GPU job") + + logger.info(f"ENV: {os.environ}") + + # set GPU device + assert 0 <= local_rank < 8 + if dist_args.matmul_allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + logger.warning( + f"WARNING: Setting torch.backends.matmul.allow_tf32 to True. This is faster but less accurate." 
+ ) + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = ( + dist_args.allow_bf16_reduced_precision_reduction + ) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group(init_method="env://", backend="nccl") + torch.autograd.set_detect_anomaly(dist_args.detect_anomaly) + + +def get_module(module, access_string): + names = access_string.split(sep=".") + return reduce(getattr, names, module) + + +def set_module(module, access_string, value): + names = access_string.split(sep=".") + parent = reduce(getattr, names[:-1], module) + setattr(parent, names[-1], value) + + +def default_fsdp_grouping_plan(n_layers: int) -> List[Tuple[str, bool]]: + return [(f"layers.{i}", i < n_layers - 1) for i in range(n_layers)] + + +def get_default_policy(no_recompute_ops=None): + no_recompute_ops = no_recompute_ops or default_no_recompute_ops + + def default_policy(ctx, func, *args, **kwargs): + return ( + CheckpointPolicy.MUST_SAVE + if func in no_recompute_ops + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return default_policy + + +@torch.no_grad() +def check_model_value_range( + model: torch.nn.Module, range: float = 1e3, std: float = 1e3 +): + for name, param in chain(model.named_parameters(), model.named_buffers()): + if isinstance(param, DTensor): + param = param.to_local() + + if param.numel() == 0: + logger.warning( + f"Model parameter {name} is empty, probably because of FSDP sharding" + ) + continue + + if torch.isnan(param).any() or torch.isinf(param).any(): + logger.warning(f"Model parameter {name} contains NaN or Inf") + + param_range = param.max() - param.min() + param_std = param.std() + if param_range > range: + logger.warning( + f"Model parameter {name} has a suspiciously large range ({param_range}): please check initialization and init_weights is defined and called" + ) + if param_std > std: + logger.warning( + f"Model parameter {name} has a suspiciously large standard deviation ({param_std}): please check initialization and init_weights is defined and called" + ) + if (param == 0).all(): + logger.warning( + f"Model parameter {name} is all zeros: it might be because of a missing initialization" + ) + + +def init_signal_handler(callable): + """ + Handle signals sent by SLURM for time limit / pre-emption. 
+ """ + signal.signal(signal.SIGUSR2, callable) + logger.warning("Signal handler installed.") + + +def requeue_slurm_job(): + prod_id = int(os.environ["SLURM_PROCID"]) + logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) + if prod_id == 0 and os.environ.get("LAUNCH_WITH", "") != "DORA": + logger.warning("Requeuing job " + os.environ["SLURM_JOB_ID"]) + os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"]) + else: + logger.warning("Not the master process, no need to requeue.") + sys.exit(0) + + +@contextlib.contextmanager +def clean_env(): + distrib_names = ( + "MASTER_ADDR", + "MASTER_PORT", + "RANK", + "WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_WORLD_SIZE", + "TORCHELASTIC_RUN_ID", + "DORA_FORCE_DISTRIB", + ) + cluster_env = { + x: os.environ.pop(x) + for x in os.environ + if x.startswith( + ("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_", "WANDB_") + ) + or x in distrib_names + } + try: + yield + finally: + os.environ.update(cluster_env) + + +def parallelize_model( + model, + device_mesh, + model_args, + distributed_args: DistributedArgs, + fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, + tp_parallelize=None, + no_recompute_ops=None, +): + if distributed_args.tp_size > 1: + assert ( + distributed_args.fsdp_type == "full_shard" + ), "Only full shard is supported for TP parallelism" + assert tp_parallelize is not None, "TP plan is required for TP parallelism" + assert ( + distributed_args.compile == False + ), "Compile is not supported for TP parallelism" + + tp_parallelize(model, device_mesh["tp"], model_args, distributed_args) + + if distributed_args.float8_recipe is not None: + if distributed_args.tp_size > 1: + raise RuntimeError("float8 is incompatible with tensor-parallelism for now") + model = convert_linears_to_fp8( + model, distributed_args.float8_recipe, distributed_args.float8_filter + ) + + param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ + distributed_args.model_dtype + ] + if ( + distributed_args.fsdp_type == "full_shard" + or distributed_args.fsdp_type == "no_shard" + ): + if distributed_args.fsdp_type == "no_shard": + assert ( + distributed_args.dp_shard == 1 + ), "dp_shard must be 1 for no_shard fsdp_type" + assert ( + device_mesh["dp_shard"].size() == 1 + ), "dp_shard must be 1 for no_shard fsdp_type" + + fsdp_config = dict( + mp_policy=( + MixedPrecisionPolicy( + param_dtype=param_dtype, + reduce_dtype=torch.float32, + ) + ), + mesh=( + device_mesh["dp_replicate", "dp_shard"] + if distributed_args.dp_shard > 1 + or distributed_args.fsdp_type == "no_shard" + else device_mesh["dp_replicate"] + ), + ) + + if fsdp_grouping_plan is None: + # Assume that the model has list of layers and group around it + fsdp_grouping_plan = default_fsdp_grouping_plan(len(model.layers)) + + for path, reshard_after_forward in fsdp_grouping_plan: + module = get_module(model, path) + set_module( + model, + path, + fully_shard( + module, **fsdp_config, reshard_after_forward=reshard_after_forward + ), + ) + + model = fully_shard(model, **fsdp_config, reshard_after_forward=True) + else: + raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}") + + if distributed_args.selective_activation_checkpointing: + model = checkpoint_wrapper( + model, + context_fn=partial( + create_selective_checkpoint_contexts, + get_default_policy(no_recompute_ops), + ), + ) + + if distributed_args.compile: + torch._dynamo.config.cache_size_limit = ( + distributed_args.compile_cache_size_limit + ) + model = torch.compile(model) + + 
return model diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py new file mode 100644 index 0000000..1bd1766 --- /dev/null +++ b/bytelatent/entropy_model.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import json +import os +import re + +import torch + +from bytelatent.transformer import LMTransformer, LMTransformerArgs + + +def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"): + with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: + reloaded = json.loads(fr.read()) + + torch.set_default_dtype(torch.bfloat16) + model_params = reloaded["model"] + entropy_model = LMTransformer( + LMTransformerArgs( + dim=model_params["dim"], + n_layers=model_params["n_layers"], + n_heads=model_params["n_heads"], + max_seqlen=model_params["max_length"], + ffn_dim_multiplier=model_params["ffn_dim_multiplier"], + vocab_size=model_params["vocab_size"], + ) + ) + + entropy_model.load_state_dict( + torch.load(state_dict_path, map_location=device), strict=False + ) + entropy_model.to(device) + entropy_model = entropy_model.eval() + # no grads for the model: + for param in entropy_model.parameters(): + param.requires_grad = False + return entropy_model diff --git a/bytelatent/float8.py b/bytelatent/float8.py new file mode 100644 index 0000000..6476862 --- /dev/null +++ b/bytelatent/float8.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import re +import warnings +from typing import Callable + +import torch + +# avoid division by zero when calculating scale +EPS = 1e-12 + + +def scale(t, amax_t, dtype_t): + min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max + scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v + t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t) + return t_fp8, scale_t + + +def matmul( + first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias +): + first_fp8, scale_first = scale(first, amax_first, dtype_first) + second_t_fp8, scale_second_t = scale(second_t, amax_second_t, dtype_second_t) + output = torch._scaled_mm( + first_fp8, + second_t_fp8.t(), + scale_a=scale_first, + scale_b=scale_second_t.t(), + bias=bias, + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + return output + + +@torch._dynamo.allow_in_graph +class Fp8LinearFn(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b_t, bias): + amax_a = a.abs().amax(dim=-1, keepdim=True) + amax_b_t = b_t.abs().amax(dim=-1, keepdim=True) + out = matmul( + a, amax_a, torch.float8_e4m3fn, b_t, amax_b_t, torch.float8_e4m3fn, bias + ) + + ctx.a_requires_grad = a.requires_grad + ctx.b_requires_grad = b_t.requires_grad + ctx.bias_requires_grad = bias.requires_grad if bias is not None else False + + ctx.save_for_backward(a, b_t, amax_b_t.max()) + + return out + + @staticmethod + def backward(ctx, grad_out): + a, b_t, amax_b = ctx.saved_tensors + + if ctx.a_requires_grad: + b = b_t.t().contiguous() + amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True) + amax_b = amax_b.repeat(b.shape[0], 1) + grad_a = matmul( + grad_out, + amax_grad_out, + torch.float8_e4m3fn, + b, + amax_b, + torch.float8_e4m3fn, + None, + ) + else: + grad_a = None + if ctx.b_requires_grad: + grad_b = grad_out.t() @ a + else: + grad_b = None + if ctx.bias_requires_grad: + grad_bias = grad_out.sum(dim=0) + else: + grad_bias = None + + return grad_a, grad_b, grad_bias + + +class Fp8Linear(torch.nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = 
Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias) + out = out.unflatten(0, input.shape[:-1]) + return out + + +def named_replace( + fn: Callable[[torch.nn.Module, str], torch.nn.Module], + module: torch.nn.Module, + name="", +) -> torch.nn.Module: + for child_name, child_module in list(module.named_children()): + full_name = f"{name}.{child_name}" if name else child_name + new_child_module = named_replace(fn, child_module, full_name) + setattr(module, child_name, new_child_module) + module = fn(module, name) + return module + + +def convert_linears_to_fp8( + root_module: torch.nn.Module, recipe: str, filter: str +) -> torch.nn.Module: + if recipe not in ["rowwise"]: + raise RuntimeError(f"Unknown float8 recipe {recipe!r}") + + if recipe == "rowwise" and torch.__version__ < "2.5": + # We need https://github.com/pytorch/pytorch/pull/134781. + warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0") + + # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based + # reduction kernel and a "persistent" reduction kernel. Since fp8 has some + # multi-pass steps (e.g., first get amax, then scale), persistent kernels + # should perform better. + torch._inductor.config.triton.multi_kernel = 1 + + filter_re = re.compile(filter) + + def replace(module: torch.nn.Module, name: str) -> torch.nn.Module: + if not isinstance(module, torch.nn.Linear) or not filter_re.search(name): + return module + if type(module) == torch.nn.Linear: + if recipe == "rowwise": + new_module = Fp8Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + dtype=module.weight.dtype, + device=module.weight.device, + ) + new_module.weight = module.weight + new_module.bias = module.bias + else: + assert False, recipe + else: + assert False, str(type(module)) + return new_module + + out = named_replace(replace, root_module) + + # Force re-compile everything + torch._dynamo.reset_code_caches() + from torch._inductor.cudagraph_trees import reset_cudagraph_trees + + reset_cudagraph_trees() + + return out diff --git a/bytelatent/logger.py b/bytelatent/logger.py new file mode 100644 index 0000000..6723a84 --- /dev/null +++ b/bytelatent/logger.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import logging +import math +import sys +import time +from datetime import timedelta + +from bytelatent.distributed import get_global_rank, get_is_slurm_job + + +class LogFormatter(logging.Formatter): + """ + Custom logger for distributed jobs, displaying rank + and preserving indent from the custom prefix format. 
+ """ + + def __init__(self): + self.start_time = time.time() + self.rank = get_global_rank() + self.show_rank = not get_is_slurm_job() # srun has --label + + def formatTime(self, record): + subsecond, seconds = math.modf(record.created) + curr_date = ( + time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds)) + + f".{int(subsecond * 1_000_000):06d}" + ) + delta = timedelta(seconds=round(record.created - self.start_time)) + return f"{curr_date} - {delta}" + + def formatPrefix(self, record): + fmt_time = self.formatTime(record) + if self.show_rank: + return f"{self.rank}: {record.levelname:<7} {fmt_time} - " + else: + return f"{record.levelname:<7} {fmt_time} - " + + def formatMessage(self, record, indent: str): + content = record.getMessage() + content = content.replace("\n", "\n" + indent) + # Exception handling as in the default formatter, albeit with indenting + # according to our custom prefix + if record.exc_info: + # Cache the traceback text to avoid converting it multiple times + # (it's constant anyway) + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + if record.exc_text: + if content[-1:] != "\n": + content = content + "\n" + indent + content = content + indent.join( + [l + "\n" for l in record.exc_text.splitlines()] + ) + if content[-1:] == "\n": + content = content[:-1] + if record.stack_info: + if content[-1:] != "\n": + content = content + "\n" + indent + stack_text = self.formatStack(record.stack_info) + content = content + indent.join([l + "\n" for l in stack_text.splitlines()]) + if content[-1:] == "\n": + content = content[:-1] + + return content + + def format(self, record): + prefix = self.formatPrefix(record) + indent = " " * len(prefix) + content = self.formatMessage(record, indent) + return prefix + content + + +def set_root_log_level(log_level: str): + logger = logging.getLogger() + level: int | str = log_level.upper() + try: + level = int(log_level) + except ValueError: + pass + try: + logger.setLevel(level) # type: ignore + except Exception: + logger.warning( + f"Failed to set logging level to {log_level}, using default 'NOTSET'" + ) + logger.setLevel(logging.NOTSET) + + +def init_logger( + log_file: str | None = None, + *, + name: str | None = None, + level: str = "NOTSET", +): + """ + Setup logging. + + Args: + log_file: A file name to save file logs to. + name: The name of the logger to configure, by default the root logger. + level: The logging level to use. + """ + set_root_log_level(level) + logger = logging.getLogger(name) + + # stdout: everything + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.NOTSET) + stdout_handler.setFormatter(LogFormatter()) + + # stderr: warnings / errors and above + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setLevel(logging.WARNING) + stderr_handler.setFormatter(LogFormatter()) + + # set stream handlers + logger.handlers.clear() + logger.handlers.append(stdout_handler) + logger.handlers.append(stderr_handler) + + if log_file is not None and get_global_rank() == 0: + # build file handler + file_handler = logging.FileHandler(log_file, "a") + file_handler.setLevel(logging.NOTSET) + file_handler.setFormatter(LogFormatter()) + # update logger + logger = logging.getLogger() + logger.addHandler(file_handler) diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py new file mode 100644 index 0000000..77dc4d7 --- /dev/null +++ b/bytelatent/metrics.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import json +import logging +from collections import namedtuple +from dataclasses import asdict +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Union + +import torch +import torch.nn as nn +import wandb +from pydantic import BaseModel, ConfigDict + +from bytelatent.distributed import get_is_master + +logger = logging.getLogger() + + +class WandbArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + job_type: str | None = None + dir: str | None = None + project: str | None = None + entity: str | None = None + tags: list | None = None + group: str | None = None + name: str | None = None + notes: str | None = None + config_exclude_keys: list[str] | None = None + config_include_keys: list[str] | None = None + anonymous: str | None = None + mode: str | None = None + allow_val_change: bool | None = None + resume: Union[bool, str] | None = None + force: bool | None = None + tensorboard: bool | None = None + sync_tensorboard: bool | None = None + monitor_gym: bool | None = None + save_code: bool | None = None + id: str | None = None + fork_from: str | None = None + resume_from: str | None = None + + +class LoggingArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + freq: int = 10 # Log every freq optimizer steps + acc_freq: int | None = None # Log every acc_freq gradient accumulation steps + + wandb: WandbArgs | None = None + + +class MetricLogger: + def __init__(self, outdir: Path, args: Any | None = None): + self.outdir = outdir + self.jsonl_writer = None + self.args = args + + def open(self): + if self.jsonl_writer is None: + self.jsonl_writer = open(self.outdir, "a") + if ( + self.args is not None + and self.args.logging.wandb is not None + and get_is_master() + ): + run = wandb.init( + config=asdict(self.args), + **asdict(self.args.logging.wandb), + ) + + def log(self, metrics: dict[str, Any]): + if ( + self.args is not None + and self.args.logging.wandb is not None + and (wandb.run is not None) + ): + wandb.log(metrics, step=metrics["global_step"]) + + metrics.update({"created_at": datetime.now(timezone.utc).isoformat()}) + print(json.dumps(metrics), file=self.jsonl_writer, flush=True) + + def close(self): + if self.jsonl_writer is not None: + self.jsonl_writer.close() + self.jsonl_writer = None + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + self.close() + + +GPUMemStats = namedtuple( + "GPUMemStats", + [ + "max_active_gib", + "max_active_pct", + "max_reserved_gib", + "max_reserved_pct", + "num_alloc_retries", + "num_ooms", + "power_draw", + ], +) + + +class GPUMemoryMonitor: + """ + Class to monitor GPU memory usage + """ + + def __init__(self, device: str = "cuda:0"): + self.device = torch.device(device) # device object + self.device_name = torch.cuda.get_device_name(self.device) + self.device_index = torch.cuda.current_device() + self.device_capacity = torch.cuda.get_device_properties( + self.device + ).total_memory + self.device_capacity_gib = self._to_gib(self.device_capacity) + + # reset stats, clear cache + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + def _to_gib(self, memory_in_bytes): + # NOTE: GiB (gibibyte) is 1024, vs GB is 1000 + _gib_in_bytes = 1024 * 1024 * 1024 + memory_in_gib = memory_in_bytes / _gib_in_bytes + return memory_in_gib + + def _to_pct(self, memory): + return 100 * 
memory / self.device_capacity + + def get_peak_stats(self): + cuda_info = torch.cuda.memory_stats(self.device) + + max_active = cuda_info["active_bytes.all.peak"] + max_active_gib = self._to_gib(max_active) + max_active_pct = self._to_pct(max_active) + + max_reserved = cuda_info["reserved_bytes.all.peak"] + max_reserved_gib = self._to_gib(max_reserved) + max_reserved_pct = self._to_pct(max_reserved) + + num_retries = cuda_info["num_alloc_retries"] + num_ooms = cuda_info["num_ooms"] + power_draw = torch.cuda.power_draw() + + if num_retries > 0: + logger.warning(f"{num_retries} CUDA memory allocation retries.") + if num_ooms > 0: + logger.warning(f"{num_ooms} CUDA OOM errors thrown.") + + return GPUMemStats( + max_active_gib, + max_active_pct, + max_reserved_gib, + max_reserved_pct, + num_retries, + num_ooms, + power_draw, + ) + + def reset_peak_stats(self): + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + + def __str__(self): + mem_stats = self.get_peak_stats() + display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib} GiB capacity, " + display_str += ( + f"{mem_stats.max_reserved_gib} GiB peak, {mem_stats.max_reserved_pct}% peak" + ) + return f"{display_str}" + + +def upload_train_to_wandb( + ckpt_dir, project="lingua", entity="codegen-team", train=True, eval=True +): + import json + from pathlib import Path + + import wandb + from omegaconf import OmegaConf + + cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml") + cfg = OmegaConf.to_container(cfg) + + if train: + wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity) + + with open(Path(ckpt_dir) / "metrics.jsonl") as f: + for l in f: + m = json.loads(l) + wandb.log(m, step=m["global_step"]) + + wandb.finish() + + if eval: + wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity) + + with open(Path(ckpt_dir) / "metrics.eval.jsonl") as f: + for l in f: + m = json.loads(l) + wandb.log( + { + f"evals/{name.replace('/','.')}": value + for name, value in m.items() + if "/" in name + }, + step=m["global_step"], + ) + + wandb.finish() + + +def get_num_params(model: nn.Module) -> int: + """ + Get the total model params + Args : only_trainable: whether to only count trainable params + """ + numel = {n: p.numel() for n, p in model.named_parameters()} + return sum(numel.values()) diff --git a/bytelatent/model/__init__.py b/bytelatent/model/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/model/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py new file mode 100644 index 0000000..9332d19 --- /dev/null +++ b/bytelatent/model/blt.py @@ -0,0 +1,1064 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
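# Illustrative sketch (not a definitive recipe) of the monitoring helpers from
# bytelatent/metrics.py above; `model` is assumed to be any nn.Module that accepts
# a batch of byte ids on a CUDA device, and the helper name is made up:
def _example_gpu_memory_monitor(model) -> None:
    import torch

    from bytelatent.metrics import GPUMemoryMonitor, get_num_params

    monitor = GPUMemoryMonitor("cuda:0")  # resets peak stats and empties the cache
    model(torch.randint(0, 256, (1, 128), device="cuda"))
    stats = monitor.get_peak_stats()  # GPUMemStats: GiB/% peaks, alloc retries, OOMs, power draw
    print(f"{get_num_params(model)} params; peak reserved {stats.max_reserved_gib:.2f} GiB")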
+ +from enum import Enum, auto +from typing import Any, Optional + +import torch +from pydantic import ConfigDict, model_validator +from torch import nn +from torch.nn.attention.flex_attention import create_block_mask +from typing_extensions import Self + +from bytelatent.base_transformer import ( + BaseTransformerArgs, + InitStdFactor, + TransformerBlock, +) +from bytelatent.data.patcher import Patcher, PatcherArgs +from bytelatent.model.local_models import LocalDecoder, LocalEncoder +from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.utils import downsample +from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID + + +def attention_flops_per_token(n_layers, seq_len, dim, causal): + # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30 + return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1)) + + +def get_num_flop_per_token( + num_non_embed_params: int, n_layers: int, dim: int, seq_len: int +) -> int: + return 6 * num_non_embed_params + attention_flops_per_token( + n_layers, seq_len, dim, True + ) + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def setattrs(_self, **kwargs): + for k, v in kwargs.items(): + setattr(_self, k, v) + + +def get_encoder_dim_token_emb(args): + if args.dim_token is not None: + dim_token_emb = args.dim_token + elif args.use_local_encoder_transformer: + dim_token_emb = args.dim_local_encoder + else: + dim_token_emb = args.dim_global // args.patch_size + return dim_token_emb + + +def get_encoder_dim_patch_emb(args): + dim_patch_emb = None + if args.cross_attn_encoder: + if args.cross_attn_init_by_pooling: + dim_patch_emb = args.dim_local_encoder + else: + dim_patch_emb = args.dim_global + return dim_patch_emb + + +def get_global_dim_patch_emb(args): + dim_token_emb = get_encoder_dim_token_emb(args) + if args.cross_attn_encoder: + dim_patch_emb = dim_token_emb * args.cross_attn_k + elif ( + args.downsampling_by_pooling is None + or not args.downsampling_by_pooling + or len(args.downsampling_by_pooling) == 0 + ): + dim_patch_emb = dim_token_emb * args.patch_size + else: + dim_patch_emb = dim_token_emb * sum( + [ + pooling in args.downsampling_by_pooling + for pooling in ["avg", "min", "max"] + ] + ) + return dim_patch_emb + + +def get_decoder_dim_token_emb(args): + if args.share_encoder_decoder_emb: + dim_token_emb = get_encoder_dim_token_emb(args) + elif args.dim_token is not None: + dim_token_emb = args.dim_token + else: + dim_token_emb = args.dim_local_decoder + return dim_token_emb + + +def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]: + if ngram_to_size_str is None: + return None + ngram_to_size = {} + for entry in ngram_to_size_str.split(","): + ngram, size = entry.split(":") + ngram = int(ngram) + size = int(size) + ngram_to_size[ngram] = size + return ngram_to_size + + +def fill_tokens(tokens, patch_size, fill_id): + batch_size, seq_len = tokens.shape + if seq_len % patch_size == 0: + return tokens + else: + remaining = patch_size - seq_len % patch_size + final_padding = tokens.new(batch_size, remaining).fill_(fill_id) + return torch.cat((tokens, final_padding), dim=1) + + +def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len): + first_patch_length = patch_lengths[0, 0] + assert torch.all( + first_patch_length == patch_lengths[:, 0] + ), "first patch should always be the same size (1 for dynamic, patch_size for static)." 
+ assert ( + first_patch_length - nb_boe == 1 + ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})" + # Remove first patch from patch_ids for local decoder inputs and shift the last patch. + # decoder_patch_lengths = patch_lengths[:, 1:].clone() + # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1) + decoder_patch_lengths = patch_lengths[:, 1:] + assert ( + decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0] + == patch_lengths.sum() + ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}" + assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}" + decoder_patch_ids = patch_ids_from_lengths( + patch_lengths=decoder_patch_lengths, seq_len=seq_len + ) + return decoder_patch_ids + + +primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, +] + + +def rolling_polynomial_hash(t, hash_func_nb: int = 0): + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) + prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) + return torch.sum(t * prime_powers, dim=-1) + + +def get_rolling_polynomial_hash_fn(hash_func_nb: int = 0, group_size: int = 2): + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64) + prime_powers = torch.stack([prime**i for i in range(group_size)]) + + def rolling_polynomial_hash_fn(t): + return torch.sum(t * prime_powers, dim=-1) + + return rolling_polynomial_hash_fn + + +def byte_group_hash_function( + x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 +): + """ + Returns a hash of the input x and maps it to a value in the range [0, max_hash]. + + expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. + returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. + + Note: max hash can make a big difference on the number of collisions. + """ + with torch.no_grad(): + bs, seq_len = x.shape + # x_numpy = x.numpy() + # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False) + # for i in range(bs): + # for j in range(seq_len): + # start = max(j, j-group_size+1) + # end = j+1 + # hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash) + + prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) + x = torch.cat([prefix, x], dim=1) + windows = x.unfold(1, group_size, 1) + # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values_range = hashes % max_hash + hash_values_range.requires_grad = False + return hash_values_range + + +def create_patch_mask_from_ids( + patch_ids, num_patches, window=None, patches_as_queries=False +): + """ + Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) + is True if the patch id at position (i, j) is less than or equal to k. + Args: + patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. + num_patches (int): Total number of patches. + window (int): If not None, only considers patches within a window of size window. + patches_as_queries (bool): If True, the patches are used as queries + Returns: + torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. 
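    Example:
        With patches_as_queries=False, window=None, patch_ids=[[0, 0, 1]] and
        num_patches=2, the mask is [[[True, False], [True, False], [False, True]]]:
        each byte position is matched to the slot of the patch it belongs to.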
+ """ + bs, seq_len = patch_ids.shape + if not patches_as_queries: + q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) + kv_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(bs, seq_len, num_patches) + ) + else: + kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) + q_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(bs, num_patches, seq_len) + ) + if window is None: + mask = q_ids == kv_ids + else: + mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) + return mask + + +def cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=1, + window=None, + block_mask=True, +): + bs = patch_ids.shape[0] + with torch.no_grad(): + # Create the patch mask + cross_mask = create_patch_mask_from_ids( + patch_ids, + patch_lengths.shape[1], + window=window, + patches_as_queries=patches_as_queries, + ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) + q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N + kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k + assert cross_mask.shape == ( + bs, + q_len, + kv_len, + ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + if block_mask: + + def patch_mask(b, h, q_idx, kv_idx): + return cross_mask[b, q_idx, kv_idx] + + block_mask = create_block_mask( + patch_mask, + B=bs, + H=None, + Q_LEN=q_len, + KV_LEN=kv_len, + _compile=True, + ) + return block_mask + else: + return torch.where( + cross_mask, torch.tensor(0.0), torch.tensor(float("-inf")) + ).unsqueeze( + 1 + ) # [bs, 1, q_len, kv_len] + + +def get_blt_input( + tokens: torch.Tensor, + enforce_patch_size_multiple: bool, + nb_boe: torch.Tensor, + patch_size: int, + boe_id: int, +): + """ + This function returns X_et, X_gt and X_dt, the encoder, global, and decoder + tokens respectively. + + Consider the input and target sequences: + X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13] + Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14] + with patch_size=4 + + Note 1: that there will be no special tokens introduced at the patch level. 
+ Note 2: X_e needs to be trimmed to be passed to Global + + Current without boe: + X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] + X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + --> lag fix: + X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]] + X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]] + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + Dynamic (current): + X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + entropy patching: + input: 7, bos, 9, 10 + pred (high entropy): eos, 8, 10, eos + + X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos] + X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]] + X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + --> lag fix no boe (force single byte first patch): + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + input: 4, 7, bos, 9, 10 + pred (high entropy): 5, eos, 8, 10, eos + + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + Handle the last byte properly. + patch_lengths = [1, 1, 3, 2, 2 1 2 2 1] + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]] + + + bpe delim + X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12] + X_g = [[3], [4,5,6,7,], [eos,bos,], .. + X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], .. + Y = [4,5,6,7,, eos,bos, 8,9,, .. + + + Note 1: that there will be no special tokens introduced at the patch level. + Note 2: X_e needs to be trimmed to be passed to Global + """ + batch_size, seq_len = tokens.shape + local_encoder_tokens = tokens + local_decoder_tokens = tokens + + if nb_boe > 0: + padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id) + local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1) + # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id) + + # create global tokens, contains boe tokens and eos + # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size) + # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:] + # global_tokens += global_tokens.eq(0).int() * boe_id + # TODO: fix this when we want to use block causal in the global. 
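    # The global token ids are currently derived later, in ByteLatentTransformer.forward
    # (boe everywhere, with eos copied into the patches that contain it), so this helper
    # returns None for the global stream.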
+ + if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0: + local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + + return local_encoder_tokens, None, local_decoder_tokens + + +def patch_ids_from_lengths(patch_lengths, seq_len): + bs, num_patches = patch_lengths.shape + # Create a tensor of cumulative sums of the patch lengths + cum_d = torch.cat( + [ + torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1), + ], + dim=-1, + ) + patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum( + dim=-2 + ) - 1 + assert not ( + torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0 + ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0" + return patch_ids + + +class ByteLatentTransformerArgs(BaseTransformerArgs): + model_config = ConfigDict(extra="forbid") + # Basic model configuration + seed: int = 42 + vocab_size: int = -1 + dim: int = 512 + n_layers: int = 8 + n_heads: int = 8 + # TODO: What is the purpose of this parameter? + weight_tying: bool = False + sliding_window: Optional[int] = None + + # Architecture and dimensions + dim_token: int = 256 + dim_global: int = 512 + dim_local_decoder: int = 512 + dim_local_encoder: int = 512 + n_layers_global: int = 8 + n_layers_local_decoder: int = 8 + n_layers_local_encoder: int = 8 + + # Tokenization and patching + tokenization_mode: str = "bpe" + patch_size: float | None = None + patching_mode: str | None = None + patching_threshold: float | None = None + patching_threshold_add: float | None = None + monotonicity: bool = False + patching_batch_size: int = 1 + patching_device: str = "cuda" + data_loader_patching: bool = False + max_patch_length: int | None = None + + # Encoder/Decoder configuration + tie_local_encoder_decoder_logits: bool = False + use_local_encoder_transformer: bool = False + encoder_lm_loss: bool = False + max_encoder_seq_length: int | None = None + pad_to_max_length: bool = False + encoder_enable_byte_ngrams: bool = False + encoder_enable_byte_group_hash: bool = False + ngram_vocab_sizes: int | None = None + + # Cross attention configurations + cross_attn_encoder: bool = False + cross_attn_decoder: bool = False + cross_attn_window_encoder: int | None = None + cross_attn_window_decoder: int | None = None + cross_attn_k: int | None = None + cross_attn_nheads: int | None = None + cross_attn_all_layers_decoder: bool = False + cross_attn_all_layers_encoder: bool = False + cross_attn_use_flex_attention: bool = True + cross_attn_init_by_pooling: bool = False + + # Encoder hash configurations + encoder_hash_byte_group_size: Any | None = None + encoder_hash_byte_group_vocab: int = 30000 + encoder_hash_byte_group_nb_functions: int = 3 + + # Model behavior and optimization + log_patch_lengths: bool = False + non_linearity: str = "swiglu" + use_rope: bool = True + recompute_fc1_out: bool = False + recompute_fc3_out: bool = False + recompute_attn: bool = True + custom_bwd: bool = False + layer_ckpt: str = "all" + efficient_attn: str | None = None + + # Architecture options + patch_only_encoder: bool = False + patch_only_decoder: bool = False + + # Initialization and attention + init_use_gaussian: bool = True + init_use_depth: str = "current" + attn_bias_type: str = "causal" + alpha_depth: str = "disabled" + max_length: int = 2048 + + # Norm configuration + norm_eps: float = 1e-5 + norm_affine: bool = True + pre_norm: bool = True + norm_type: str = "rmsnorm" + + 
# Additional configurations + multiple_of: int = 256 + ffn_dim_multiplier: float = 1.0 + dropout: float = 0 + output_size: int = -1 + + # Additional parameters from ModelArgs + architecture: str = "vanilla" + share_encoder_decoder_emb: bool = True + global_local_decoder_residual_layer: str | None = None + + tokenize_with_bpe_delimiter: bool = False + patching_thresholds_str: str | None = None + tie_local_encoder_decoder: bool = False + encoder_preds_low_entropy_toks: float | None = None + encoder_preds_random_toks: float | None = None + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + encoder_ngram_table_dir: str | None = None + encoder_ngram_to_size_str: str | None = None + + # Model architecture params + entropy_model_checkpoint_dir: str | None = None + entropy_model_is_ngram_model: bool = False + downsampling_by_pooling: str | None = None + n_heads_global: int = 8 + n_heads_local_decoder: int = 8 + n_heads_local_encoder: int = 8 + n_kv_heads: int | None = None + n_kv_heads_global: int | None = None + conv_kernel_size: int | None = None + local_attention_window_len: int | None = None + + # Performance optimization + sequence_parallel: bool = False + loss_parallel: bool = False + fuse_sequence_parallel: bool = False + use_fsdp: bool = True + attn_to_keep: str = "all" + + # RoPE parameters + rope_theta: float = 10000.0 + rope_use_fp32_in_outer_product: bool = False + + # Parameter mixing + pm_size: int = 0 + + # Logging + full_logging_n_layers: int = 4 + + # Special token config + eos_id: int | None = None + + @model_validator(mode="after") + def check_hash_byte_sizes(self) -> Self: + if ( + self.encoder_hash_byte_group_size is not None + and type(self.encoder_hash_byte_group_size) == str + ): + self.encoder_hash_byte_group_size = [ + int(x) + for x in self.encoder_hash_byte_group_size.split(",") + if len(x) > 0 + ] + return self + + +class LocalEncoderArgs(ByteLatentTransformerArgs): + # Local encoder specific dimensions + n_heads_local_encoder: int = 8 + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with local encoder specific values + self.dim = self.dim_local_encoder + self.n_layers = self.n_layers_local_encoder + self.n_heads = self.n_heads_local_encoder + self.cross_attn_decoder = False + self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None + self.attn_bias_type = "local_block_causal" + + +class GlobalTransformerArgs(ByteLatentTransformerArgs): + # Global encoder specific dimensions + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with global encoder specific values + self.dim = self.dim_global + self.n_layers = self.n_layers_global + self.n_heads = self.n_heads_global + self.n_kv_heads = self.n_kv_heads_global + self.local_attention_window_len = None + self.cross_attn_encoder = False + self.cross_attn_decoder = False + + +class LocalDecoderArgs(ByteLatentTransformerArgs): + # Local decoder specific dimensions + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with local decoder specific values + self.dim = self.dim_local_decoder + self.n_layers = self.n_layers_local_decoder + self.n_heads = self.n_heads_local_decoder + self.cross_attn_encoder = False + self.cross_attn_init_by_pooling = False + self.attn_bias_type = "local_block_causal" + + +def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer: 
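    """Build the patch-level global transformer from a deep copy of the BLT config,
    overriding dim/n_layers/n_heads/n_kv_heads with their *_global counterparts and
    disabling the byte-level cross-attention flags."""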
+ global_args = args.model_copy( + deep=True, + update=dict( + dim=args.dim_global, + n_layers=args.n_layers_global, + n_heads=args.n_heads_global, + n_kv_heads=args.n_kv_heads_global, + local_attention_window_len=None, + dim_token_emb=get_global_dim_patch_emb(args), + dim_patch_emb=None, + cross_attn_encoder=False, + cross_attn_decoder=False, + ), + ) + + return GlobalTransformer(global_args) + + +def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: + # First deep copy the original args + # Replace with local encoder specific values + local_encoder_args = args.model_copy( + deep=True, + update=dict( + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + attn_bias_type="local_block_causal", + ), + ) + + return LocalEncoder(local_encoder_args) + + +def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: + # First deep copy the original args + local_decoder_args = args.model_copy( + deep=True, + update=dict( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + cross_attn_encoder=False, + cross_attn_init_by_pooling=False, # states are already defined + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + ), + ) + + return LocalDecoder(local_decoder_args) + + +class EmbeddingType(Enum): + HASH_TOK = auto() + NGRAM = auto() + + +def init_embeddings( + args, + embedding_type: EmbeddingType, + local_encoder_dim: int, + encoder_hash_byte_group_size: list = None, +): + if ( + embedding_type == EmbeddingType.HASH_TOK + and args.encoder_hash_byte_group_size is None + ): + return None + if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None: + return None + + embeddings = [] + + if embedding_type == EmbeddingType.HASH_TOK: + emb_dim = local_encoder_dim + encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab + for _ in range(args.encoder_hash_byte_group_nb_functions): + for _ in encoder_hash_byte_group_size: + embeddings.append( + nn.Embedding( + encoder_hash_byte_group_vocab, + emb_dim, + ) + ) + + elif embedding_type == EmbeddingType.NGRAM: + encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str) + emb_dim = local_encoder_dim + OFFSET = 4 # This should be passed as parameter if it's variable + for ngram_vocab_size in encoder_ngram_to_size.values(): + embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim)) + + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """ + Compute embeddings using hash token embeddings. 
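    For every (hash function, byte-group size) pair, a rolling window of byte ids is
    hashed into [0, encoder_hash_byte_group_vocab) and the matching embedding table is
    added on top of the plain byte embeddings, so the output keeps the shape of
    local_encoder.tok_embeddings(local_encoder_tokens).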
+ + Args: + local_encoder_tokens: Input tokens tensor + local_encoder: Encoder object with tok_embeddings method + encoder_hash_tok_embedding: ModuleList of hash token embeddings + encoder_hash_byte_group_nb_functions: Number of hash functions + encoder_hash_byte_group_size: List of byte group sizes + encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings + + Returns: + torch.Tensor: Combined embeddings + """ + if encoder_hash_tok_embedding is None: + return None + + local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) + + i = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for byte_group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, + byte_group_size, + hash_func_nb=func_nb, + max_hash=encoder_hash_byte_group_vocab, + ) + hash_tok_embedding = encoder_hash_tok_embedding[i] + local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) + i += 1 + + assert i == len(encoder_hash_tok_embedding) + return local_encoder_embeds + + +class ByteLatentTransformer(nn.Module): + """ + The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences + by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, + and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for + improved performance and inference efficiency. + """ + + def __init__(self, args: ByteLatentTransformerArgs): + super().__init__() + + # General configuration + self.weight_tying = args.weight_tying + self.sliding_window = args.sliding_window + self.patch_size = args.patch_size + self.patching_mode = args.patching_mode + self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( + BOE_ID, + BOS_ID, + PAD_ID, + EOS_ID, + ) + self.downsampling_by_pooling = args.downsampling_by_pooling + self.patching_threshold = args.patching_threshold + self.dim = args.dim + self.init_base_std = args.init_base_std + self.init_std_factor = InitStdFactor(args.init_std_factor) + self.max_seqlen = args.max_seqlen + + # Cross attention configuration + self.cross_attn_encoder = args.cross_attn_encoder + self.cross_attn_decoder = args.cross_attn_decoder + self.cross_attn_k = args.cross_attn_k + self.cross_attn_window_encoder = args.cross_attn_window_encoder + self.cross_attn_window_decoder = args.cross_attn_window_decoder + self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention + + # Encoder hash configuration + self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size + self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab + self.encoder_hash_byte_group_nb_functions = ( + args.encoder_hash_byte_group_nb_functions + ) + + # ByteLatent modules + self.local_encoder = create_local_encoder(args) + self.global_transformer = create_global_transformer(args) + self.local_decoder = create_local_decoder(args) + self.encoder_hash_tok_embedding = init_embeddings( + args, + EmbeddingType.HASH_TOK, + local_encoder_dim=self.local_encoder.dim, + encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, + ) + self.encoder_ngram_embedding = init_embeddings( + args, + EmbeddingType.NGRAM, + local_encoder_dim=self.local_encoder.dim, + encoder_hash_byte_group_size=None, + ) + self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) + + # Transformer layers + self.layers = nn.ModuleList( + [TransformerBlock(args) for _ in range(args.n_layers)] + ) + + # Encoder ngram 
embedding tables + self.encoder_ngram_embedding = None + if args.encoder_enable_byte_ngrams: + self.encoder_ngram_embedding = nn.ModuleList() + assert args.ngram_vocab_sizes is not None + self.encoder_ngram_to_size = parse_ngram_to_size( + args.encoder_ngram_to_size_str + ) + ngram_emb_dim = self.local_encoder.dim + for ngram_vocab_size in self.encoder_ngram_to_size.values(): + self.encoder_ngram_embedding.append( + nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim) + ) + + # Output layer + assert args.vocab_size > 0, "vocab_size must be greater than 0" + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + if args.weight_tying: + self.output.weight = self.tok_embeddings.weight + + # Patcher module + if not args.data_loader_patching: + self.patcher = Patcher( + PatcherArgs( + patch_size=args.patch_size, + patching_mode=args.patching_mode, + patching_threshold=args.patching_threshold, + patching_threshold_add=args.patching_threshold_add, + monotonicity=args.monotonicity, + max_patch_length=args.max_patch_length, + ) + ) + + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + ngram_ids: Optional[torch.Tensor] = None, + ): + # Ensure ngram_ids is either a tensor or None + assert ( + isinstance(ngram_ids, torch.Tensor) or ngram_ids is None + ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}" + + bs, N = tokens.shape # Batch size and sequence length + + # Get megabyte inputs + nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1) + local_encoder_tokens, _, local_decoder_tokens = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=nb_boe, + patch_size=self.patch_size, + boe_id=self.boe_id, + ) + + # Patching + if patch_lengths is None: + assert ( + getattr(self, "patcher", None) is not None + ), "Patcher not defined and no patch_lengths passed." 
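            # Dynamic patching: the Patcher (e.g. entropy-based) segments the byte
            # stream on the fly, returning one length per patch and per-token scores.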
+ patch_lengths, tok_scores = self.patcher.patch( + local_encoder_tokens, + include_next_token=True, + threshold=self.patcher.threshold, + ) + else: + if nb_boe > 0: + patch_lengths[:, 0] += nb_boe + + assert torch.min(patch_lengths) >= 0 + + # Generate patch IDs from patch_lengths + patch_ids = patch_ids_from_lengths( + patch_lengths, local_encoder_tokens.shape[-1] + ) + assert torch.max(patch_ids) + 1 <= torch.max( + (patch_lengths != 0).sum(dim=-1) + ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" + + cross_attn_mask_enc = None + # Cross-attention encoder + if self.cross_attn_encoder: + cross_attn_mask_enc = cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=True, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_encoder, + block_mask=self.cross_attn_use_flex_attention, + ) + + # Hashing and embedding + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=self.local_encoder, + encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab, + ) + + # N-gram table embeddings + if self.encoder_ngram_embedding is not None: + assert ngram_ids is not None, "ngram_ids must be provided" + if local_encoder_embeds is None: + local_encoder_embeds = self.local_encoder.tok_embeddings( + local_encoder_tokens + ) + assert len(ngram_ids) == len( + self.encoder_ngram_embedding + ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}" + for i in range(ngram_ids.shape[0]): + ngram_embedding = self.encoder_ngram_embedding[i] + ngram_embeds = ngram_embedding(ngram_ids[i]) + assert ( + local_encoder_embeds.shape == ngram_embeds.shape + ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}" + local_encoder_embeds = local_encoder_embeds + ngram_embeds + + # Local encoder + h_cross = None + (h_encoder, h_cross), cache_encoder = self.local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=h_cross if self.cross_attn_encoder else None, + cross_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + # Downsampling + if not self.cross_attn_encoder: + assert ( + patch_ids.shape[1] == h_encoder.shape[1] + ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}" + h = downsample( + h_encoder, + patch_lengths.shape[1], + patch_lengths, + patch_ids, + downsampling_by_pooling=self.downsampling_by_pooling, + patch_size=self.patch_size, + ) + else: + # Reshape h_cross + h = h_cross.view(bs, patch_lengths.shape[1], -1) + + # Global transformer + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id) + rows, cols = torch.where(local_encoder_tokens == self.eos_id) + eos_patch_ids = patch_ids[rows, cols] + global_tokens[rows, eos_patch_ids] = self.eos_id + + h, _ = self.global_transformer( + embeds=h, + tokens=global_tokens, + ) + + # Unpatching + dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] + + # Generate decoder patch IDs + decoder_patch_ids = decoder_patch_ids_from_lengths( + patch_lengths, nb_boe, local_decoder_tokens.shape[-1] + ) + assert ( + torch.max(decoder_patch_ids) + 1 <= h.shape[1] + ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + 
assert ( + decoder_patch_ids.shape[1] == dec_embeds.shape[1] + ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + + # Cross-attention decoder + if not self.cross_attn_decoder: + h = torch.gather( + h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + ) + cross_attn_mask_dec = None + assert local_decoder_tokens.shape == h.shape[:-1] + else: + cross_attn_mask_dec = cross_attn_mask( + decoder_patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_decoder, + block_mask=self.cross_attn_use_flex_attention, + ) + + # Local decoder + output, _ = self.local_decoder( + embeds=dec_embeds, + patch_embeds=h, + tokens=local_decoder_tokens, + cross_mask=cross_attn_mask_dec, + ) + return output + + def reset_parameters(self, init_std=None): + # Either use fixed base std or sqrt model dim + init_std = init_std or (self.dim ** (-0.5)) + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + if not self.weight_tying: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + def init_weights(self): + self.reset_parameters() + self.init_base_std = self.init_base_std or (self.dim ** (-0.5)) + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(self.init_base_std, factor) + + self.local_decoder.init_weights(self.init_base_std) + self.global_transformer.init_weights(self.init_base_std) + self.local_encoder.init_weights(self.init_base_std) + + for emb in self.encoder_hash_tok_embedding: + nn.init.trunc_normal_( + emb.weight, + mean=0.0, + std=self.init_base_std, + a=-3 * self.init_base_std, + b=3 * self.init_base_std, + ) diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py new file mode 100644 index 0000000..8255504 --- /dev/null +++ b/bytelatent/model/local_models.py @@ -0,0 +1,356 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
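# A tiny worked example (values and helper name chosen only for illustration) for the
# patch bookkeeping helpers defined in bytelatent/model/blt.py above:
def _example_patch_ids() -> None:
    import torch

    from bytelatent.model.blt import byte_group_hash_function, patch_ids_from_lengths

    lengths = torch.tensor([[1, 4, 2, 3]])  # 4 patches covering a 10-byte sequence
    ids = patch_ids_from_lengths(lengths, seq_len=10)
    # -> tensor([[0, 1, 1, 1, 1, 2, 2, 3, 3, 3]]): byte position -> patch index
    hashes = byte_group_hash_function(torch.arange(10).unsqueeze(0), group_size=2)
    # -> shape (1, 10): one bucketed rolling-hash id per byte, which indexes the
    #    auxiliary hash-embedding tables used by the local encoder
    print(ids, hashes.shape)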
+ +import logging +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn +import torch.nn as nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import BlockMask +from xformers.ops import AttentionBias + +from bytelatent.base_transformer import ( + InitStdFactor, + RMSNorm, + RotaryEmbedding, + TransformerBlock, +) +from bytelatent.model.transformer import CrossAttention +from bytelatent.model.utils import create_causal_mask, downsample +from bytelatent.tokenizers.blt_tokenizer import BOE_ID + +logger = logging.getLogger() + + +class LocalModelBase(nn.Module): + def __init__(self, args): + super().__init__() + + self.dim = args.dim + self.dropout = args.dropout + self.vocab_size = args.vocab_size + args.pm_size + self.patch_size = args.patch_size + + self.efficient_attn = args.efficient_attn + self.sliding_window = args.sliding_window + self.use_rope = args.use_rope + self.init_std_factor = args.init_std_factor + self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) + self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) + self.cross_attn_k = getattr(args, "cross_attn_k", None) + + self.boe_id = BOE_ID + + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.layers = nn.ModuleList( + [TransformerBlock(args) for _ in range(args.n_layers)] + ) + + self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim) + if not self.use_rope: + self.pos_embeddings = nn.Embedding(args.max_length, args.dim) + else: + self.rope = RotaryEmbedding( + theta=args.rope_theta, + head_dim=args.head_dim or args.dim // args.n_heads, + max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + ) + self.pos_embeddings = None + + self.token_embedding_projection = ( + nn.Linear(args.dim_token_emb, args.dim, bias=False) + if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(args) + + def _should_create_patch_projection(self, args): + dimension_mismatch = ( + getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim + ) + + # Check cross attention conditions + cross_attn_conditions = ( + hasattr(args, "cross_attn_encoder") + and args.cross_attn_encoder + and getattr(args, "cross_attn_init_by_pooling") + ) or ( + hasattr(args, "cross_attn_decoder") + and args.cross_attn_decoder + and getattr(args, "cross_attn_init_by_pooling") + ) + + return dimension_mismatch or cross_attn_conditions + + def _create_patch_projection(self, args): + if not self._should_create_patch_projection(args): + return None + + output_dim = args.dim_token_emb * (self.cross_attn_k or 1) + + return nn.Linear( + in_features=args.dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + return embeds + else: + return self.tok_embeddings(tokens) + + def init_weights(self, init_std=None): + self.rope.reset_parameters() + + init_std = init_std or (self.dim ** (-0.5)) + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + if self.pos_embeddings is not None: + nn.init.trunc_normal_( + self.pos_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + 
InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(init_std, factor) + + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.patch_embedding_projection is not None: + nn.init.trunc_normal_( + self.patch_embedding_projection.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if hasattr(self, "output"): + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.cross_attn_layers is not None: + for depth, layer in enumerate(self.cross_attn_layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(init_std, factor) + + +class LocalEncoder(LocalModelBase): + def __init__(self, args): + super().__init__(args) + self.output_proj = ( + args.patching_mode in ["entropy", "probmax"] + ) and args.entropy_model_checkpoint_dir is None + + self.apply_transformer = args.use_local_encoder_transformer + self.downsampling_by_pooling = args.downsampling_by_pooling + self.patch_only = args.patch_only_encoder + self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None + self.cross_attn_encoder = args.cross_attn_encoder + self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling + self.cross_attn_nheads = args.cross_attn_nheads + + if self.cross_attn_encoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + CrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=args.norm_eps, + ) + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + assert ( + self.expects_hash_embeddings + ), "Not expecting embeddings to be passed." 
+ return embeds + else: + return self.tok_embeddings(tokens) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + bs, seqlen = tokens.shape + if mask is None: + mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + + h = self.apply_embedding(tokens, embeds) + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + + for i, layer in enumerate(self.layers): + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + # check if cross attention should be applied to either all layer or only the last layer + if self.cross_attn_encoder and ( + i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder + ): + patch_embeds = self.apply_cross_attention( + h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask + ) + + h_residual = patch_embeds if self.cross_attn_encoder else None + return (h, h_residual), cache + + def apply_cross_attention( + self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask + ): + # apply pooling and project + if self.cross_attn_init_by_pooling and patch_embeds is None: + patch_embeds = downsample( + h, + num_patches, + patch_ids=patch_ids, + downsampling_by_pooling=self.downsampling_by_pooling, + patch_size=self.patch_size, + ) + if self.patch_embedding_projection is not None: + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0 + patch_embeds_cross = self.cross_attn_layers[layer_idx]( + x=patch_embeds, + kv=h, + mask=cross_mask, + ) + patch_embeds += patch_embeds_cross + return patch_embeds + + +class LocalDecoder(LocalModelBase): + def __init__(self, args): + super().__init__(args) + + # Model configuration flags + self.patch_only = args.patch_only_decoder + self.expects_embeddings = args.share_encoder_decoder_emb + self.cross_attn_decoder = args.cross_attn_decoder + self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder + self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling + self.cross_attn_nheads = args.cross_attn_nheads + + if self.cross_attn_decoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + CrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=args.norm_eps, + ) + ) + + self.output = nn.Linear( + self.dim, + args.vocab_size, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + assert embeds is not None, "Embeddings must be provided" + + if mask is None: + mask = 
create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + + h = embeds + + if self.patch_embedding_projection is not None: + assert patch_embeds is not None, "Patch embeddings must be passed." + patch_embeds = self.patch_embedding_projection(patch_embeds) + if self.cross_attn_k is not None: + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + if patch_embeds is not None and not self.cross_attn_decoder: + h = h + patch_embeds + + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if self.cross_attn_decoder and ( + i == 0 or self.cross_attn_all_layers_decoder + ): + # Use cross attention to extract info from patch_embeds into h + h_cross = self.cross_attn_layers[i]( + x=h, + kv=patch_embeds, + mask=cross_mask, + ) + h = h + h_cross + + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + + h_preds = self.norm(h) + h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) + h_preds = self.output(h_preds) + h_preds = h_preds.float() + return h_preds, cache diff --git a/bytelatent/model/transformer.py b/bytelatent/model/transformer.py new file mode 100644 index 0000000..24dc057 --- /dev/null +++ b/bytelatent/model/transformer.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn +import torch.nn as nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import BlockMask +from xformers.ops import AttentionBias + +from bytelatent.base_transformer import ( + BaseTransformer, + RMSNorm, + flex_attention_comp, + repeat_kv, +) +from bytelatent.model.utils import create_causal_mask + +logger = logging.getLogger() + + +class CrossAttention(nn.Module): + """ + CrossAttention block to attend to the encoder states from the decoder. + Rope is not supported. 
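+
+    Illustrative usage (hypothetical shapes, not part of the original source):
+
+        attn = CrossAttention(dim=512, head_dim=64, n_heads=8, n_kv_heads=8, norm_eps=1e-5)
+        # queries come from patch embeddings, keys/values from byte-level hidden states
+        out = attn(x=patch_queries, kv=byte_states, mask=None)  # same shape as patch_queries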
+ """ + + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + kv: torch.Tensor, + mask: Optional[Union[BlockMask, AttentionBias, str]] = None, + ) -> torch.Tensor: + # B S D + bsz, seq_len, _ = x.shape + _, slen_kv, _ = kv.shape + x = self.cross_attn_norm_q(x) + kv = self.cross_attn_norm_kv(kv) + + xq = self.wq(x) + xk = self.wk(kv) + xv = self.wv(kv) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + output = self.wo(output.reshape(output_shape)) + + return x + output + + def init_weights(self, base_std: float, factor: float = 1.0): + std = base_std * factor + + nn.init.trunc_normal_( + self.wq.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wk.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wv.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + output_std = std / (2**0.5) + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=output_std, + a=-3 * output_std, + b=3 * output_std, + ) + self.cross_attn_norm_q.reset_parameters() + self.cross_attn_norm_kv.reset_parameters() + + +class GlobalTransformer(BaseTransformer): + def __init__(self, args): + super().__init__(args) + self.dropout = args.dropout + self.sliding_window = args.sliding_window + self.efficient_attn = args.efficient_attn + + self.token_embedding_projection = None + if args.dim_token_emb is not None and args.dim_token_emb != self.dim: + self.token_embedding_projection = nn.Linear( + args.dim_token_emb, + args.dim, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + embeds: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ + Similar to BaseTransformer.forward, but with an additional embeds argument + and projection to the token space. 
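+
+        Note: `embeds` is expected to carry the patch-level representations
+        ([bs, num_patches, dim_token_emb]); when dim_token_emb differs from the model
+        dimension they are projected by token_embedding_projection before the layers run.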
+ """ + bs, seqlen = tokens.shape + attn_impl = self.efficient_attn + + h = embeds + + mask = ( + mask + if mask is not None + else create_causal_mask(seqlen, attn_impl, self.sliding_window) + ) + + if self.token_embedding_projection is not None and h.shape[-1] != self.dim: + h = self.token_embedding_projection(h) + + h = F.dropout(h, p=self.dropout, training=self.training) + + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + return h, cache + + def init_weights(self, init_base_std: float): + super().init_weights() + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=init_base_std, + a=-3 * init_base_std, + b=3 * init_base_std, + ) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py new file mode 100644 index 0000000..ce52a30 --- /dev/null +++ b/bytelatent/model/utils.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +from torch.nn.attention.flex_attention import create_block_mask +from xformers.ops import fmha + + +def patch_reduce(h, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + bs, seq_len, emb_dim = h.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + + reduced_embs = torch.zeros( + (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device + ) + reduced_embs = reduced_embs.scatter_reduce( + src=h, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embs = reduced_embs[:, :max_num_patches, :] + + return reduced_embs + + +def concat_downsample(h, patch_lengths, patch_size): + # The assumption in this function is that seq_len = patch_size * num_patches. + bs, seq_len, emb_dim = h.shape + patch_end_ids = torch.cumsum(patch_lengths, dim=1) + patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to( + patch_end_ids.device + ) + # Is clamp ok here? + patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1]) + patch_ids = patch_ids.view(bs, -1, emb_dim) + # after gather h.shape = [batch_size, seq_len, dim] + h = torch.gather(h, 1, patch_ids) + h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1)) + return h + + +def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids): + cat = [] + if "avg" in pooling_mode or "mean" in pooling_mode: + cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids)) + if "min" in pooling_mode: + cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids)) + if "max" in pooling_mode: + cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids)) + assert len(cat) > 0 + h = torch.cat(cat, dim=-1) + return h + + +def downsample( + h, + num_patches, + patch_lengths=None, + patch_ids=None, + downsampling_by_pooling=None, + patch_size=4, +): + """ + Downsampling: + a. concatenating embeddings in the patch + Note: with dynamic patching, patch the last patch_size tokens. + b. 
pooling embeddings in the patch + """ + # input: h.shape = [batch_size, seq_len, dim] + # input: pool h.shape = [batch_size, seq_len / patch_size, dim] + # if we don't use the cros_attn, we pool so that we convert bytes rep to patch rep + if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0: + # By pooling + max_num_patches = num_patches + assert patch_ids is not None + h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids) + else: + # TODO: remove this condition + # By concatenating (fixed lengths patching) + assert patch_lengths is not None + h = concat_downsample(h, patch_lengths, patch_size) + return h + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def create_causal_mask(seqlen, attn_impl, sliding_window): + if sliding_window is not None and attn_impl == "xformers": + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) + elif attn_impl == "xformers": + return fmha.attn_bias.LowerTriangularMask() + elif attn_impl == "sdpa": + return "causal" + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + elif attn_impl == "fmha": + return None + else: + raise NotImplementedError( + f"Attention {attn_impl} with {sliding_window} sliding window not implemented" + ) diff --git a/bytelatent/optim.py b/bytelatent/optim.py new file mode 100644 index 0000000..c6e1829 --- /dev/null +++ b/bytelatent/optim.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import logging +import math +from functools import partial + +from pydantic import BaseModel, ConfigDict +from torch import nn +from torch.optim import AdamW, lr_scheduler + +logger = logging.getLogger() + + +class OptimArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + lr: float = 3e-4 + weight_decay: float = 0.1 + epsilon: float = 1e-8 + beta1: float = 0.9 + beta2: float = 0.95 + clip: float = 1.0 + + scheduler: str = "cosine" + warmup: int = 2000 + lr_min_ratio: float = 0.1 + cycle_length: float = 1.0 + cosine_theta: float = 1.0 + annealing_step: int = 1000 + decay_fraction: float = 0.1 + + exp_factor: float = 0.5 + + +def lr_linear(step: int, warmup: int, n_steps: int, min_ratio: float) -> float: + if step < warmup: + lr = float(step) / warmup + elif step <= n_steps: + s = float(step - warmup) / (n_steps - warmup) + lr = s * min_ratio + (1 - s) + else: + lr = min_ratio + return lr + + +def lr_inv_sqrt(step: int, warmup: int, exp_factor: float, min_ratio: float) -> float: + if step < warmup: + lr = float(step) / warmup + else: + lr = max((warmup**exp_factor) / (step**exp_factor), min_ratio) + return lr + + +def lr_cosine( + step: int, + warmup: int, + n_steps: int, + cycle_length: float, + theta: float, + min_ratio: float, +) -> float: + sign = ((step // (n_steps * cycle_length)) % 2) * -2 + 1 + if step < warmup: + lr = float(step) / warmup + elif step <= n_steps: + s = float(step - warmup) / (n_steps - warmup) + lr = min_ratio + 0.5 * (1 - min_ratio) * ( + sign * math.cos(math.pi * s**theta / cycle_length) + 1 + ) + else: + lr = min_ratio + return lr + + +def lr_wsd( + step: int, + warmup: int, + n_steps: int, + decay_fraction: float, + cycle_length: float, + min_ratio: float, +) -> float: + """ + UNDERSTANDING WARMUP-STABLE-DECAY LEARNING RATES: A RIVER VALLEY LOSS LANDSCAPE PERSPECTIVE + https://arxiv.org/pdf/2410.05192 + """ + cycle_num = step // int(n_steps * cycle_length) + 1 + curr_n_steps = int(n_steps * cycle_length) * 
cycle_num + decay_length = int(curr_n_steps * decay_fraction) + + if step < warmup: + lr = float(step) / warmup + elif step <= curr_n_steps - decay_length: + lr = 1.0 + elif step > curr_n_steps - decay_length and step <= curr_n_steps: + # Linear interpolation gives similar results + # slope = -(1.0 - min_ratio) / decay_length + # intercept = min_ratio + ((1.0 - min_ratio) * curr_n_steps) / decay_length + # lr = slope * step + intercept + + step = step - (curr_n_steps - decay_length) + lr = 1 / ((step / curr_n_steps) * (1 / min_ratio) + (1 - step / curr_n_steps)) + else: + lr = min_ratio + + return lr + + +def build_lr_fn(args: OptimArgs, n_steps: int): + if args.scheduler == "constant": + lr_fn = lambda x: 1.0 + elif args.scheduler == "linear": + lr_fn = partial( + lr_linear, warmup=args.warmup, n_steps=n_steps, min_ratio=args.lr_min_ratio + ) + elif args.scheduler == "inv_sqrt": + lr_fn = partial( + lr_inv_sqrt, + warmup=args.warmup, + exp_factor=args.exp_factor, + min_ratio=args.lr_min_ratio, + ) + elif args.scheduler == "cosine": + lr_fn = partial( + lr_cosine, + warmup=args.warmup, + n_steps=n_steps, + cycle_length=args.cycle_length, + theta=args.cosine_theta, + min_ratio=args.lr_min_ratio, + ) + elif args.scheduler == "wsd": + assert args.decay_fraction < args.cycle_length + lr_fn = partial( + lr_wsd, + warmup=args.warmup, + n_steps=n_steps, + decay_fraction=args.decay_fraction, + cycle_length=args.cycle_length, + min_ratio=args.lr_min_ratio, + ) + else: + raise NotImplementedError(f"Unknown scheduler: {args.scheduler}") + return lr_fn + + +def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int): + logger.info("Starting build of optimizer...") + optimizer = AdamW( + model.parameters(), + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + eps=args.epsilon, + fused=True, # Faster optim.step but can throw errors + ) + + # scheduler + lr_fn = build_lr_fn(args, n_steps) + scheduler = lr_scheduler.LambdaLR(optimizer, lr_fn) + + logger.info("Done with build of optimizer.") + return optimizer, scheduler diff --git a/bytelatent/plotting/__init__.py b/bytelatent/plotting/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/plotting/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/plotting/config_entropy_figure.yaml b/bytelatent/plotting/config_entropy_figure.yaml new file mode 100644 index 0000000..4d7bfd7 --- /dev/null +++ b/bytelatent/plotting/config_entropy_figure.yaml @@ -0,0 +1,3 @@ +data_path: plot_data/entropy_figure.json +chart_path: figures/entropy_figure.pdf +# chart_path: figures/entropy_figure.pdf diff --git a/bytelatent/plotting/config_scaling_figures.yaml b/bytelatent/plotting/config_scaling_figures.yaml new file mode 100644 index 0000000..cda85c2 --- /dev/null +++ b/bytelatent/plotting/config_scaling_figures.yaml @@ -0,0 +1,4 @@ +df_dir: /home/par/blt_df +output_chart_dir: figures/ +frame_files: + ["4b_df.json", "500m_df.json", "scaling_arch_df.json", "scaling_df.json"] diff --git a/bytelatent/plotting/entropy_figure.py b/bytelatent/plotting/entropy_figure.py new file mode 100644 index 0000000..c1966a1 --- /dev/null +++ b/bytelatent/plotting/entropy_figure.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
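+#
+# Usage sketch (assuming the default config shipped in this directory):
+#   python -m bytelatent.plotting.entropy_figure bytelatent/plotting/config_entropy_figure.yaml
+# Extra key=value arguments on the command line are merged into the config via OmegaConf.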
+import os +import sys +from pathlib import Path + +import altair as alt +import pandas as pd +from omegaconf import OmegaConf +from pydantic import BaseModel + + +class PlotEntropiesConfig(BaseModel): + data_path: str | None + chart_path: str + + class Config: + extra = "forbid" + + +class PlotEntropiesData(BaseModel): + text: str + threshold: float = 1.335442066192627 + dataframe_json: str | None + + class Config: + extra = "forbid" + + +def main(): + config_path = sys.argv[1] + file_config = OmegaConf.load(config_path) + # Omit program name and config file name + cli_conf = OmegaConf.from_cli(sys.argv[2:]) + conf_dict = OmegaConf.to_container( + OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True + ) + plot_config = PlotEntropiesConfig(**conf_dict) + with open(plot_config.data_path) as f: + json_data = f.read() + plot_data = PlotEntropiesData.model_validate_json(json_data) + df = pd.read_json(plot_data.dataframe_json) + + x_ticks = [] + for row in df.itertuples(): + position = row.position + token = row.tokens + x_ticks.append(f"{str(position).zfill(3)}|{token}") + df["position_with_token"] = x_ticks + print(df) + + x_axis = alt.Axis( + labelExpr="split(datum.label, '|')[1]", + grid=False, + labelOverlap=False, + labelAngle=0, + ) + width = 1200 + height = 150 + base = alt.Chart(df).properties(width=width, height=height) + points = base.mark_line(point=True).encode( + x=alt.X("position_with_token:O", title=None, axis=x_axis), + y=alt.Y( + "entropies", + title="Entropy of Next Byte", + ), + ) + rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode( + y=alt.datum(plot_data.threshold), + ) + patch_rules = ( + alt.Chart(df[df["start"] > 0]) + .properties(width=width, height=height) + .mark_rule(color="#474747", strokeDash=[4, 2]) + .encode(x=alt.X("position_with_token:O", axis=x_axis)) + ) + + chart = patch_rules + rule + points + chart = chart.configure_axis(labelFontSize=15, titleFontSize=15) + path = Path(plot_config.chart_path) + path.parent.mkdir(exist_ok=True) + chart.save(path) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/plotting/scaling_figures.py b/bytelatent/plotting/scaling_figures.py new file mode 100644 index 0000000..a49624a --- /dev/null +++ b/bytelatent/plotting/scaling_figures.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
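+#
+# Usage sketch (assuming the default config in this directory):
+#   python -m bytelatent.plotting.scaling_figures bytelatent/plotting/config_scaling_figures.yaml
+# Reads the dataframes listed in `frame_files` from `df_dir` and writes one PDF chart per frame
+# into `output_chart_dir`.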
+import sys +from pathlib import Path + +import altair as alt +import pandas as pd +import pydantic +from omegaconf import OmegaConf + + +class ScalingPlotsConfig(pydantic.BaseModel): + df_dir: str + output_chart_dir: str + frame_files: list[str] + + class Config: + extra = "forbid" + + +def determine_family(key: str): + if key.startswith("Megabyte++"): + return "Megabyte++" + elif key.startswith("BLT"): + return "BLT" + elif key.startswith("LLaMA"): + return "LLaMA" + elif key.startswith("Space"): + return "Space" + + +file_to_vars = {} + + +def create_chart(df: pd.DataFrame, output_file: str): + df["metric"] = df["bpb/not_heldout.jsonl"] + df["family"] = df["key"].map(determine_family) + model_domain = [ + "BLT Space ps=6", + "BLT Space w/o cross-attn", + "SpaceByte", + "LLaMA 3 BPE", + "Megabyte++ ps=4", + "Megabyte++ ps=6", + ] + color_range = ["#1f77b4", "#1f77b4", "#1f77b4", "#ff7f0e", "#2ca02c", "#2ca02c"] + shape_range = [ + "circle", + "square", + "cross", + "diamond", + "triangle-up", + "triangle-down", + ] + color_scale = alt.Scale(domain=model_domain, range=color_range) + shape_scale = alt.Scale( + domain=model_domain, + range=shape_range, + ) + base_chart = alt.Chart(df).encode( + x=alt.X("flops", title="Training FLOPS") + .scale(type="log", domain=[2e20, 1.25e22]) + .axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]), + y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False), + ) + lines = base_chart.encode( + color=alt.Color("key", title="Model Color", scale=color_scale, legend=None), + strokeDash=alt.StrokeDash("family", title="Model Family", legend=None), + ).mark_line() + points = base_chart.encode( + color=alt.Color("key", title="Model", scale=color_scale), + shape=alt.Shape("key", title="", scale=shape_scale), + ).mark_point(size=70) + chart = ( + (lines + points) + .resolve_scale( + color="independent", + shape="independent", + # strokeDash="independent", + ) + .configure_legend(orient="right") + .properties(height=300, width=400) + ) + print("Saving", output_file) + chart.save(output_file) + + +def main(): + config_path = sys.argv[1] + file_config = OmegaConf.load(config_path) + # Omit program name and config file name + cli_conf = OmegaConf.from_cli(sys.argv[2:]) + conf_dict = OmegaConf.to_container( + OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True + ) + plot_config = ScalingPlotsConfig(**conf_dict) + df_dir = Path(plot_config.df_dir) + chart_dir = Path(plot_config.output_chart_dir) + chart_dir.mkdir(exist_ok=True, parents=True) + for ff in plot_config.frame_files: + path = df_dir / ff + df = pd.read_json(path) + print(df) + print(df.columns) + create_chart(df, chart_dir / f"{path.name}.pdf") + + +if __name__ == "__main__": + main() diff --git a/bytelatent/preprocess/__init__.py b/bytelatent/preprocess/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/preprocess/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/preprocess/data_pipeline.py b/bytelatent/preprocess/data_pipeline.py new file mode 100644 index 0000000..7dd5283 --- /dev/null +++ b/bytelatent/preprocess/data_pipeline.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
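+#
+# Usage sketch: point BASE_DIR / DATASETS / TARGET_DIR below at your data, then run
+#   python -m bytelatent.preprocess.data_pipeline
+# to split every *.chunk.*.jsonl file into ~2.5 GB shards via luigi's local scheduler.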
+import subprocess +from pathlib import Path + +import luigi + +# CHANGEME: Change this to point to your data +BASE_DIR = Path("datasets") +DATASETS = ["dclm"] +TARGET_DIR = Path("entropy_preprocess") + +SHARD_SCRIPT = """split -C 2500m -d {source} {destination}.shard_""" + + +def list_dataset_shards(dataset: str): + dataset_dir = BASE_DIR / dataset + return list(dataset_dir.glob("*.chunk.*.jsonl")) + + +class ChunkFile(luigi.ExternalTask): + file = luigi.Parameter() + + def output(self): + return luigi.LocalTarget(self.file) + + +class ShardDatasetChunk(luigi.Task): + dataset_name = luigi.Parameter() + chunk_file = luigi.Parameter() + + def _chunk_filename(self): + return Path(self.chunk_file).name + + def requires(self): + return ChunkFile(self.chunk_file) + + def run(self): + destination_dir = TARGET_DIR / str(self.dataset_name) + destination_dir.mkdir(parents=True, exist_ok=True) + destination = destination_dir / self._chunk_filename() + subprocess.check_output( + SHARD_SCRIPT.format(source=str(self.chunk_file), destination=destination), + shell=True, + ) + ( + Path(TARGET_DIR) + / str(self.dataset_name) + / f"{self._chunk_filename()}.shard.COMPLETE" + ).touch() + + def output(self): + return luigi.LocalTarget( + TARGET_DIR + / str(self.dataset_name) + / f"{self._chunk_filename()}.shard.COMPLETE" + ) + + +class ShardDataset(luigi.WrapperTask): + dataset_name = luigi.Parameter() + + def requires(self): + for f in list_dataset_shards(self.dataset_name): + yield ShardDatasetChunk(dataset_name=self.dataset_name, chunk_file=str(f)) + + +class ShardAllDatasets(luigi.WrapperTask): + def requires(self): + for d in DATASETS: + yield ShardDataset(dataset_name=d) + + +if __name__ == "__main__": + luigi.build([ShardAllDatasets()], local_scheduler=True, workers=128) diff --git a/bytelatent/preprocess/parallel_entropies.py b/bytelatent/preprocess/parallel_entropies.py new file mode 100644 index 0000000..ac2dbd4 --- /dev/null +++ b/bytelatent/preprocess/parallel_entropies.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
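+#
+# Usage sketch (paths are illustrative): submits one Slurm array task per input shard, e.g.
+#   python -m bytelatent.preprocess.parallel_entropies jobs/ entropy_preprocess/dclm entropies/dclm
+# Pass --check-only to report which .arrow outputs are still missing without launching jobs.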
+import subprocess +from pathlib import Path + +import submitit +import typer + + +class PreprocessEntropiesJob(submitit.helpers.Checkpointable): + def __init__(self) -> None: + pass + + def __call__(self, shard_file: str, output_filename: str): + subprocess.run( + [ + "python", + "-u", + "-m", + "bytelatent.preprocess.preprocess_entropies", + str(shard_file), + str(output_filename), + ], + check=True, + ) + return True + + +def chunk(items, size): + for i in range(0, len(items), size): + yield items[i : i + size] + + +def main( + job_folder: str, + input_dir: str, + output_dir: str, + qos: str = "explore", + slurm_batch_size: int = 1000, + check_only: bool = False, + wait: bool = False, +): + input_dir = Path(input_dir) + output_dir = Path(output_dir) + shard_files = [ + p for p in input_dir.glob("*.jsonl.shard*") if "COMPLETE" not in p.name + ] + if check_only: + exist = [] + missing = [] + for shard_file in shard_files: + shard_file = Path(shard_file) + complete_file = output_dir / f"{shard_file.name}.arrow.complete" + if complete_file.exists(): + exist.append(complete_file) + else: + missing.append(complete_file) + print("Checked for output files for input_dir=", input_dir) + print("Exist:", len(exist)) + print("Missing:", len(missing)) + print(missing) + return + print("Running parallel job over N files=", len(shard_files)) + print("Input Directory:", input_dir) + print("Output Directory:", output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + + executor = submitit.SlurmExecutor(job_folder) + executor.update_parameters( + # 12 hours in minutes + time=60 * 12, + qos=qos, + exclusive="user", + cpus_per_task=4, + num_gpus=1, + mem_per_gpu="80G", + array_parallelism=slurm_batch_size, + ) + + jobs = [] + n_batches = 0 + n_skipped = 0 + n_launched = 0 + for file_batch in chunk(shard_files, slurm_batch_size): + with executor.batch(): + for shard_file in file_batch: + output_filename = Path(output_dir) / f"{shard_file.name}.arrow" + complete_output_filename = ( + Path(output_dir) / f"{shard_file.name}.arrow.complete" + ) + if complete_output_filename.exists(): + n_skipped += 1 + else: + job = executor.submit( + PreprocessEntropiesJob(), str(shard_file), str(output_filename) + ) + n_launched += 1 + jobs.append(job) + n_batches += 1 + print("launched array jobs n=", n_launched) + print("skipped (completed) array jobs n=", n_skipped) + print("number of slurm batches=", n_batches) + if wait: + output = [job.result() for job in jobs] + assert all(output) + + +if __name__ == "__main__": + typer.run(main) diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py new file mode 100644 index 0000000..20d1e0c --- /dev/null +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
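+#
+# Usage sketch: meant to be run on a single shard, e.g.
+#   python -m bytelatent.preprocess.preprocess_entropies INPUT_SHARD OUTPUT_FILE.arrow
+# Note: main() currently raises NotImplementedError before doing any work (see the TODO below).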
+import time +from pathlib import Path + +import numpy as np +import pyarrow as pa +import torch +import typer +from rich.progress import Progress, TextColumn + +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator + + +def main( + input_file: str, + output_file: str, + patching_device: str = "cuda", + log_step: int = 10_000, + entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir", + dry_run: bool = False, +): + # TODO: Modify this to work with the new code + raise NotImplementedError() + iterator = ArrowFileIterator( + file_path=input_file, + worker_id=0, + num_workers=1, + ) + tokenization_mode = "bytes" + print(f"Preprocessing entropies, input: {input_file}, output: {output_file}") + print("Loading entropy model", entropy_model_checkpoint_dir) + if dry_run: + return + entropy_model = load_entropy_model( + entropy_model_checkpoint_dir, device=patching_device + ) + entropy_model, _ = to_device(entropy_model, patching_device) + print("Creating patcher") + patching_batch_size = 32 + print("Creating tokenizer") + tokenizer = Tokenizer( + model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", + tokenization_mode=tokenization_mode, + # BYTE_UNITS + vocab_size_unit_1=256, + bos=True, + eos=True, + bpe_delim=False, + # This isn't used, just stores a reference for other calls we don't use + patcher=None, + ) + step = 0 + print("starting") + start_time = time.time() + patch_time = 0 + entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False) + sample_id_field = pa.field("sample_id", pa.string(), nullable=False) + text_field = pa.field("text", pa.string(), nullable=False) + schema = pa.schema([sample_id_field, text_field, entropy_field]) + arrow_batch_size = 1_000 + + try: + with pa.OSFile(output_file, "wb") as sink: + with pa.ipc.new_file(sink, schema) as writer: + id_buffer = [] + entropies_buffer = [] + text_buffer = [] + with Progress( + *Progress.get_default_columns(), + TextColumn("Completed: {task.completed}"), + ) as progress: + task = progress.add_task( + "[green]Calculating entropies...", total=None + ) + for doc in iterator: + sample_id = get_id_from_doc(doc) + + if "text" in doc: + text = doc["text"] + elif "content" in doc: + text = doc["content"] + else: + raise ValueError( + f"Could not find a text key from: {doc.keys()}" + ) + tokens = torch.tensor(tokenizer.encode(text)) + patch_start = time.time() + scores = calculate_entropies( + tokens, + entropy_model, + patching_batch_size, + patching_device, + ) + entropies_buffer.append( + np.array(scores.tolist(), dtype=np.float16) + ) + id_buffer.append(sample_id) + text_buffer.append(text) + if len(entropies_buffer) == arrow_batch_size: + batch = pa.record_batch( + { + "entropies": entropies_buffer, + "sample_id": id_buffer, + "text": text_buffer, + }, + schema, + ) + writer.write(batch) + entropies_buffer = [] + id_buffer = [] + text_buffer = [] + patch_time += time.time() - patch_start + step += 1 + if step % log_step == 0: + print("Completed steps:", step) + progress.update(task, advance=1) + if len(entropies_buffer) > 0: + # Write last things + batch = pa.record_batch( + { + "entropies": entropies_buffer, + "sample_id": id_buffer, + "text": text_buffer, + }, + schema, + ) + writer.write(batch) + entropies_buffer = [] + id_buffer = [] + text_buffer = [] + Path(f"{output_file}.complete").touch() + except: + Path(output_file).unlink(missing_ok=True) + raise + elapsed = time.time() - start_time + print("steps", step) + print("done in:", elapsed) + + +if __name__ == "__main__": + 
typer.run(main) diff --git a/bytelatent/probe.py b/bytelatent/probe.py new file mode 100644 index 0000000..7bdc03c --- /dev/null +++ b/bytelatent/probe.py @@ -0,0 +1,694 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This file from the xFormers repo is just a example of how to implement +# probing of the activations of a model, without changing anything. +# By default, the linear inputs/outputs/gradients are logged, as well as +# the attention logits+entropy. It is possible to log an additional tensor, eg: +# x = log_stats(x, "name") +# +# Known limitations: +# * Only a subset of the attention biases is supported +# * Torch-compile is disabled automatically when this is enabled +# * Only tested with bf16/f16/f32 datatypes + +import contextlib +import functools +import json +import math +import os +import uuid +from collections import defaultdict +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + checkpoint_wrapper, +) +from torch.fx.operator_schemas import normalize_function +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map +from torch.utils.module_tracker import ModuleTracker +from xformers.ops import fmha + + +@torch.library.custom_op("torchprobe::log", mutates_args=(), device_types=None) +def _log(x: torch.Tensor, name: str, uid: str) -> None: + pass + + +@_log.register_fake +def _log_fake(x: torch.Tensor, name: str, uid: str) -> None: + pass + + +class _LogStats(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, name: str): + uid = str(uuid.uuid4()) + torch.ops.torchprobe.log(x, name, uid) + ctx.name = name + ctx.uid = uid + return x + + @staticmethod + def backward(ctx, grad: torch.Tensor): + torch.ops.torchprobe.log(grad, f"{ctx.name}.g", ctx.uid) + return grad, None + + +_PROBING_ENABLED = False + + +def log_stats(x: torch.Tensor, name: str) -> torch.Tensor: + if not _PROBING_ENABLED: + return x + return _LogStats.apply(x, name) + + +QUANTILES = [ + 0.0000001, + 0.000001, + 0.00001, + 0.0001, + 0.001, + 0.01, + 0.05, + 0.1, + 0.3, + 0.5, + 0.7, + 0.9, + 0.95, + 0.99, + 0.999, + 0.9999, + 0.99999, + 0.999999, + 0.9999999, +] + + +@functools.cache +def _get_quantiles(device: torch.device, dtype) -> torch.Tensor: + return torch.tensor(QUANTILES, device=device, dtype=dtype) + + +def _get_stats(x_: torch.Tensor, remove_inf=False) -> Dict[str, Any]: + if x_.dtype not in [torch.float, torch.double, torch.float16, torch.bfloat16]: + return {} + x = x_.flatten() + if remove_inf: + x = x[x.abs() < float("inf")] + if x.dtype is not torch.double: + x = x.float() + xabs = x.abs() + quantiles = _get_quantiles(x.device, x.dtype) + mean = x.mean() + std = x.std() + return { + "shape": tuple(x_.shape), + "mean": mean, + "std": std, + "skew": (((x - mean) / std) ** 3).double().mean(), + "kurtosis": (((x - mean) / std) ** 4).double().mean(), + "abs.mean": xabs.mean(), + "max": x.max(), + "min": x.min(), + # Note: `quantile` takes at most 2**24 elements, see + # https://github.com/pytorch/pytorch/issues/64947 + "quantiles": torch.quantile(x[: 2**24], quantiles), + } + + +def _mask_attn_causal_inplace(logits: torch.Tensor, q_idx, q_len, kv_len) -> None: + assert logits.ndim == 4 + logits[:, :, :, q_idx + kv_len - q_len + 1 :] = -math.inf + + 
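+# Added note: _mask_attn_causal_inplace fills the logits of keys past the (bottom-right aligned)
+# causal boundary of a single query row with -inf. _mask_attn_logits below applies it to each
+# sampled query and, for xformers block-diagonal layouts, additionally masks keys belonging to a
+# different sequence in the batch.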
+def _mask_attn_logits( + logits: torch.Tensor, + q_idx: List[int], + *, + causal: bool, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert logits.dtype is torch.float32 + # Handle BlockDiagonalMask + if cu_seqlens_q is not None: + assert cu_seqlens_k is not None + # Expect BHMqMkv + assert logits.ndim == 4, logits.shape + qs = cu_seqlens_q.tolist() + ks = cu_seqlens_k.tolist() + q_batchid = [] + k_batchid = [-2] * logits.shape[-1] + q_idx_i = 0 + for bid, (q0, q1, k0, k1) in enumerate(zip(qs, qs[1:], ks, ks[1:])): + for k in range(k0, k1): + k_batchid[k] = bid + while q_idx_i < len(q_idx) and q_idx[q_idx_i] < q1: + q_batchid.append(bid) + if causal: + _mask_attn_causal_inplace( + logits[:, :, q_idx_i : q_idx_i + 1, k0:k1], + q_idx[q_idx_i] - q0, + q1 - q0, + k1 - k0, + ) + q_idx_i += 1 + mask_out = ( + torch.tensor(q_batchid, device=logits.device)[None, None, :, None] + != torch.tensor(k_batchid, device=logits.device)[None, None, None, :] + ) + logits[mask_out.expand_as(logits)] = -math.inf + assert q_idx_i == len(q_idx) + elif causal: + for q_idx_i in range(len(q_idx)): + _mask_attn_causal_inplace( + logits[:, :, q_idx_i : q_idx_i + 1, :], + q_idx[q_idx_i], + logits.shape[2], + logits.shape[3], + ) + return logits + + +def _attn_queries_subset(num_queries: int) -> List[int]: + return list(range(0, num_queries, max(1, num_queries // 128))) + + +@torch.no_grad() +def _compute_attn_stats_sdpa( + probe, + path: str, + # supports arguments both cudnn + flash backends + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask=None, + attn_bias=None, + dropout_p=0.0, + is_causal=False, + scale=None, + compute_log_sumexp=True, + return_debug_mask=False, + **kwargs, +): + if scale is None: + scale = 1 / (query.shape[-1] ** 0.5) + # Filter-out not supported cases + if attn_mask is not None or attn_bias is not None or dropout_p != 0.0 or kwargs: + probe.store[f"{path}::attn"] = { + "query.shape": tuple(query.shape), + "key.shape": tuple(key.shape), + "value.shape": tuple(value.shape), + "attn_mask": attn_mask.shape if attn_mask is not None else None, + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "unk_kwargs": list(kwargs.keys()), + } + return + # Take a subset of the queries and compute the logits + query_s = _attn_queries_subset(query.shape[-2]) + logits = query[:, :, query_s] @ key.transpose(-1, -2) * scale + logits = _mask_attn_logits(logits.float(), query_s, causal=is_causal) + p = logits.float().softmax(-1) + masked_logsoft = logits.log_softmax(-1).where( + (logits > -math.inf), torch.zeros_like(logits) + ) + entropy = -(p * masked_logsoft).sum(-1) + probe.log_tensor(f"{path}::attn_entropy", entropy) + probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True) + + +@torch.no_grad() +def _compute_attn_stats_flash( + probe, + path: str, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + p: float, + softmax_scale: float, + is_causal: bool, + window_left: int, + window_right: int, + return_softmax: bool, + block_tables: Optional[torch.Tensor], + unpadded_lse: bool = False, +) -> None: + # Filter-out not supported cases + if ( + seqused_k is not None + or p != 0.0 + or window_left >= 0 + or window_right >= 0 + or block_tables is not None + ): + probe.store[f"{path}::attn"] = { + "query.shape": 
tuple(query.shape), + "key.shape": tuple(key.shape), + "value.shape": tuple(value.shape), + "op": "flash", + } + return + + if cu_seqlens_q is not None: + assert query.ndim == 3, query.shape + query, key, value = query[None], key[None], value[None] + assert query.ndim == 4, query.shape + + # Take a subset of the queries and compute the logits + query_s = _attn_queries_subset(query.shape[1]) + logits = ( + query[:, query_s].transpose(1, 2) + @ key.transpose(1, 2).transpose(-1, -2) + * softmax_scale + ) + logits = _mask_attn_logits( + logits.float(), + query_s, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + causal=is_causal, + ) + p = logits.float().softmax(-1) + masked_logsoft = logits.log_softmax(-1).where( + (logits > -math.inf), torch.zeros_like(logits) + ) + entropy = -(p * masked_logsoft).sum(-1) + probe.log_tensor(f"{path}::attn_entropy", entropy) + probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True) + + +def _tensors_to_python(x): + if not isinstance(x, torch.Tensor): + return x + return x.tolist() + + +# class syntax +class LinearBwType(Enum): + DW = 1 + DX = 2 + UNKNOWN = 3 + + +class AutoProbeD(TorchDispatchMode): + def __init__(self, module: nn.Module, write_file: Optional[str] = None) -> None: + self.write_file = Path(write_file) if write_file is not None else None + self.write_tensors_tmpdir: Optional[Path] = None + self.compile_disabler = TorchCompileDisabler(module) + self.mod_tracker = ModuleTracker() + self.count_per_path: Dict[str, int] = defaultdict(int) + self.store: Dict[str, Dict[str, Any]] = {} + self.linear_data: Dict[str, Tuple[Any, Any, Any, Any, Any]] = {} + self.uid_to_path: Dict[str, str] = {} + self.metadata: Any = None + self.enabled = False + self.verbose = bool(int(os.environ.get("PROBE_VERBOSE", "0"))) + + def __enter__(self): + global _PROBING_ENABLED + assert not self.enabled, "Entered probe twice" + self.compile_disabler.__enter__() + self.mod_tracker.__enter__() + super().__enter__() + self.enabled = True + _PROBING_ENABLED = True + # self._setup_tensors_logging() + return self + + def __exit__(self, *args) -> None: + global _PROBING_ENABLED + assert self.enabled, "Exiting probe without entering it" + super().__exit__(*args) + self.mod_tracker.__exit__(*args) + self.compile_disabler.__exit__(*args) + self._flush_and_clear() + _PROBING_ENABLED = False + self.enabled = False + + def _setup_tensors_logging(self): + if self.write_file is not None: + self.write_file.parent.mkdir(exist_ok=True) + self.write_tensors_tmpdir = ( + self.write_file.parent + / f"{self.write_file.name}-tmp-{str(uuid.uuid4())[:8]}" + ) + self.write_tensors_tmpdir.mkdir(exist_ok=True) + + def _flush_and_clear(self) -> None: + if self.write_file is not None: + dump_data = tree_map(_tensors_to_python, self.store) + with self.write_file.open("a") as fd: + json.dump( + { + "data": dump_data, + "meta": self.metadata, + "version": 2, + "quantiles": QUANTILES, + }, + fd, + ) + fd.write("\n") + if self.write_tensors_tmpdir is not None: + assert self.write_file is not None + dump_dir = self.write_tensors_tmpdir.parent / f"{self.write_file.name}-dump" + dump_dir.mkdir(exist_ok=True) + dir_name = "" + if "it" in self.metadata: + dir_name = f"it{int(self.metadata['it']):010}" + if dir_name == "" or (dump_dir / dir_name).exists(): + num_files = len(list(dump_dir.glob(f"{dir_name}v*"))) + dir_name = f"{dir_name}v{num_files}" + dump_dir = dump_dir / dir_name + assert not dump_dir.exists() + self.write_tensors_tmpdir.rename(dump_dir) + self.write_tensors_tmpdir = None + 
self.store.clear() + self.count_per_path.clear() + self.uid_to_path.clear() + + def _find_bw_path_and_type( + self, path: str, out: torch.Tensor, args + ) -> Tuple[str, LinearBwType]: + """ + We are in the BW pass, and process a GEMM. + Let's figure out: + (1) The path for the FW pass (might differ in case of ModuleTracker bug) + (2) The type of BW pass (eg `dw` or `dx`) + """ + + def _is_path_correct_dw(path: str) -> bool: + # dW.t = dY.t @ X + in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path] + return out.shape == (w_shape[1], w_shape[0]) and torch.allclose( + input_sm, args[1][:4, :4] + ) + + def _is_path_correct_dx(path: str) -> bool: + # dX = dY @ W.t + in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path] + return out.shape == in_shape and torch.allclose(weight_sm, args[1][:4, :4]) + + if path in self.linear_data: + if _is_path_correct_dw(path): + return path, LinearBwType.DW + if _is_path_correct_dx(path): + return path, LinearBwType.DX + for candidate_path in self.mod_tracker.parents: + if candidate_path not in self.linear_data: + continue + if _is_path_correct_dw(candidate_path): + return candidate_path, LinearBwType.DW + if _is_path_correct_dx(candidate_path): + return candidate_path, LinearBwType.DX + return path, LinearBwType.UNKNOWN + + def log_tensor(self, name: str, x: torch.Tensor, **kwargs) -> None: + self.store[name] = _get_stats(x, **kwargs) + if self.write_tensors_tmpdir is not None: + name_safe = name.replace("::", "__").replace("/", "") + torch.save(x, self.write_tensors_tmpdir / f"{name_safe}.pkl") + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + path = None + # Find longest path + for p in self.mod_tracker.parents: + if p == "Global": + continue + if path is None or len(p) > len(path): + path = p + if path is None: + path = "Global" + path = path.replace("._checkpoint_wrapped_module", "") + out = func(*args, **kwargs) + + # Handle linear layers + if func._overloadpacket in [torch.ops.aten.addmm, torch.ops.aten.mm]: + weight: torch.Tensor + input: torch.Tensor + if not self.mod_tracker.is_bw: + # (technically, weight is transposed) + if func._overloadpacket == torch.ops.aten.addmm: + _bias, input, weight = args[:3] + else: + assert func._overloadpacket == torch.ops.aten.mm + input, weight = args[:2] + self.log_tensor(f"{path}::in", input) + self.log_tensor(f"{path}::w", weight) + self.log_tensor(f"{path}::out", out) + self.linear_data[path] = ( + input.shape, + weight.shape, + out.shape, + input[:4, :4].clone(), + weight[:4, :4].T.clone(), + ) + elif func._overloadpacket == torch.ops.aten.mm: + # XXX: Try to find the actual path for the linear layer + # This is messed with with Francisco's FSDP sometimes + new_path, bwtype = self._find_bw_path_and_type(path, out, args) + if new_path != path: + if self.verbose: + print(f"E: Fixing path `{path}` -> `{new_path}") + path = new_path + + if bwtype == LinearBwType.DW: + # dW.t = dY.t @ X + self.log_tensor(f"{path}::w.g", out) + elif bwtype == LinearBwType.DX: + # dX = dY @ W.t + self.log_tensor(f"{path}::in.g", out) + self.log_tensor(f"{path}::out.g", args[0]) + elif func._overloadpacket in [ + torch.ops.aten._scaled_dot_product_flash_attention, + torch.ops.aten._scaled_dot_product_cudnn_attention, + ]: + _, kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + _compute_attn_stats_sdpa(self, path, **kwargs) + elif func._overloadpacket == fmha.flash.FwOp.OPERATOR: + _, 
kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + _compute_attn_stats_flash(self, path, **kwargs) + elif func._overloadpacket == torch.ops.torchprobe.log: + uid = args[2] + path = self.uid_to_path.setdefault(uid, path) + self.log_tensor(f"{path}::{args[1]}", args[0]) + if self.verbose: + print(f"{'[BW]' if self.mod_tracker.is_bw else '[FW]'} `{path}`: {func}") + return out + + +def _find_all_submodules_compiled(out: List[nn.Module], module: nn.Module) -> None: + if module._compiled_call_impl is not None: + out.append(module) + for c in module.children(): + _find_all_submodules_compiled(out, module=c) + + +class TorchCompileDisabler: + def __init__(self, module: nn.Module) -> None: + self.module = module + self.submodules_compiled: List[nn.Module] = [] + self.compiled_call_impl: List[Any] = [] + self.disable_compile = torch.compiler.disable() + torch._dynamo.config.raise_on_ctx_manager_usage = False # type: ignore + + def __enter__(self) -> None: + # Remove all `_compiled_call_impl` attributes to effectively + # "undo" compilation + self.submodules_compiled.clear() + _find_all_submodules_compiled(self.submodules_compiled, self.module) + self.compiled_call_impl = [ + m._compiled_call_impl for m in self.submodules_compiled + ] + for m in self.submodules_compiled: + m._compiled_call_impl = None + self.disable_compile.__enter__() # type: ignore + + def __exit__(self, *args) -> None: + self.disable_compile.__exit__(*args) # type: ignore + for m, c_impl in zip(self.submodules_compiled, self.compiled_call_impl): + m._compiled_call_impl = c_impl + self.compiled_call_impl = [] + + +Probe = AutoProbeD + +# EXAMPLE USAGE +d = 512 +seqlen = 4 +bs = 2 + + +class Attention1(nn.Module): + def forward(self, x): + attn_bias = fmha.attn_bias.LowerTriangularFromBottomRightMask() + return fmha.memory_efficient_attention(x, x, x, attn_bias=attn_bias).reshape( + [x.shape[0], seqlen, -1] + ) + + +class Attention2(nn.Module): + def forward(self, x): + attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens( + [seqlen] * bs + ).make_causal() + xr = x.reshape([1, 2 * seqlen, x.shape[2], x.shape[3]]) + return fmha.memory_efficient_attention(xr, xr, xr, attn_bias=attn_bias).reshape( + [x.shape[0], seqlen, -1] + ) + + +class AttentionSDPA(nn.Module): + def __init__(self): + super().__init__() + self.wo = nn.Linear(d, d) + + def forward(self, x): + x = x.transpose(1, 2) + return self.wo( + F.scaled_dot_product_attention(x, x, x) + .transpose(1, 2) + .reshape([x.shape[0], seqlen, -1]) + ) + + +class AttentionSDPAFlash(AttentionSDPA): + def forward(self, x): + x = x.transpose(1, 2) + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + return self.wo( + F.scaled_dot_product_attention(x, x, x) + .transpose(1, 2) + .reshape([x.shape[0], seqlen, -1]) + ) + + +class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.head = nn.Linear(d, 16) + self.trunk = nn.Sequential( + nn.Linear(d, d), + nn.Linear(d, d), + ) + self.q_proj = nn.Linear(d, d, bias=False) + self.trunk.compile() + self.attn1 = Attention1() + self.attn2 = Attention2() + self.attnSDPA = AttentionSDPA() + self.attnSDPAflash = AttentionSDPAFlash() + + def forward(self, x): + B, nHeads, D = x.shape[0], d // 64, 64 + x = self.q_proj(x).reshape([B, seqlen, nHeads, D]) + x = self.attn1(x) + self.attn2(x) + self.attnSDPA(x) + self.attnSDPAflash(x) + x = log_stats(x, "attns_out") + return self.head(self.trunk(x)) + + +def test_masking() -> None: + q_seqlen = [1, 1, 14, 12] + kv_seqlen = [2, 2, 14, 
18] + attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen, kv_seqlen + ).make_causal_from_bottomright() + logits = torch.randn( + [1, 1, sum(q_seqlen), sum(kv_seqlen)], dtype=torch.float32, device="cuda" + ) + bias = attn_bias.materialize(logits.shape, dtype=logits.dtype, device=logits.device) + logits_masked = logits.clone() + _mask_attn_logits( + logits_masked, + list(range(logits.shape[2])), + causal=True, + cu_seqlens_q=attn_bias.q_seqinfo.seqstart, + cu_seqlens_k=attn_bias.k_seqinfo.seqstart, + ) + assert (logits + bias == logits_masked).all().item() + + +def test_toy_model() -> None: + # Test masking + kw = dict(device="cuda", dtype=torch.float16) + x = torch.randn([bs, seqlen, d], **kw) + m = Model() + m.head = checkpoint_wrapper( + m.head, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False + ) + m.to(**kw) + m.compile() + optim = torch.optim.SGD(m.parameters(), lr=0.0) + probe = AutoProbeD(m, "./probe.json") + + for i in range(4): + with contextlib.ExitStack() as stack: + print(f"########### STEP {i}") + if i % 4 == 1: + stack.enter_context(probe) + probe.metadata = {"it": i} + y = m(x) + g = torch.randn_like(y) + y.backward(g) + if i % 4 == 1: + assert probe.enabled + # Make sure we registered all linears + print(list(probe.store.keys())) + for key in [ + "Model::attns_out", + "Model::attns_out.g", + "Model.attn1::attn_logits", + "Model.attn2::attn_logits", + "Model.attnSDPA::attn_logits", + "Model.attnSDPAflash::attn_logits", + "Model.head::w", + "Model.head::w.g", + "Model.head::in", + "Model.head::in.g", + "Model.head::out", + "Model.head::out.g", + "Model.trunk.0::in", + "Model.trunk.1::in", + ]: + assert key in probe.store, f"Missing key: '{key}'" + # .. and that the values are correct + for key, tensor in [ + ("Model.head::w", m.head.weight), + ("Model.head::w.g", m.head.weight.grad), + ("Model.q_proj::in", x), + ("Model.q_proj::w.g", m.q_proj.weight.grad), + ("Model.head::out", y), + ("Model.head::out.g", g), + ]: + assert key in probe.store, f"Missing key: '{key}'" + assert torch.allclose( + probe.store[key]["abs.mean"], tensor.float().abs().mean() + ), f"'{key}' mismatches" + # Check we don't have `nans` + for key, value in probe.store.items(): + if "abs.mean" in value: + assert math.isfinite( + value["abs.mean"].item() + ), f"Inf/Nan for {key}" + optim.step() + optim.zero_grad() diff --git a/bytelatent/profiling.py b/bytelatent/profiling.py new file mode 100644 index 0000000..da3c90d --- /dev/null +++ b/bytelatent/profiling.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
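+#
+# Usage sketch (illustrative; assumes the standard xformers profiler stepping API):
+#   with maybe_run_profiler(dump_dir, model, ProfilerArgs(run=True)) as profiler:
+#       for step in range(num_steps):
+#           loss = train_step(...)
+#           if profiler is not None:
+#               profiler.step()
+# Traces are written under <dump_dir>/<trace_folder> and, on the master rank with an active
+# wandb run, also logged as HTML artifacts.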
+ +import contextlib +import logging +import os +from pathlib import Path + +import torch.distributed +import wandb +import xformers.profiler +from pydantic import BaseModel +from torch.profiler.profiler import profile +from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler + +from bytelatent.distributed import get_is_master + + +class ProfilerArgs(BaseModel): + run: bool = False + trace_folder: str = "profiling" + mem_warmup: int = 100 + mem_steps: int = 2 + profile_warmup: int = 102 + profile_steps: int = 2 + + +logger = logging.getLogger() + + +def perfetto_to_html(json_file, html_file): + import gzip + import string + + import viztracer + + root = os.path.dirname(viztracer.__file__) + sub = {} + json_file = gzip.open(json_file) if ".gz" in str(json_file) else open(json_file) + with open( + os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8" + ) as f: + tmpl = f.read() + with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f: + sub["trace_viewer_full"] = f.read() + with json_file as j: + content = j.read() + if isinstance(content, bytes): + content = content.decode("utf-8") + sub["json_data"] = content.replace("", "<\\/script>") # type: ignore + with open(html_file, "w+", encoding="utf-8") as output_file: + output_file.write(string.Template(tmpl).substitute(sub)) + + +class PyTorchProfilerWandb(PyTorchProfiler): + def __init__(self, main_profiler) -> None: + self.main_profiler = main_profiler + self.num_steps = 0 + self.pytorch_profiler = torch.profiler.profile( + on_trace_ready=self._on_trace, + profile_memory=True, + record_shapes=True, + # With stack gives huge profile traces + # and bugs out because of some non ascii + # character somewhere in pytorch + with_stack=False, + with_flops=True, + activities=self.ACTIVITIES, + ) + + def _analyze_trace(self, prof: profile): + logger.info("Begin analyze trace") + super()._analyze_trace(prof) + logger.info("End analyze trace") + + def _on_trace(self, prof: torch.profiler.profiler.profile) -> None: + super()._on_trace(prof) + if get_is_master() and wandb.run is not None: + filename = list( + Path(self.main_profiler.output_dir).glob( + "profile_CPU_CUDA*/*.pt.trace.json*" + ) + )[0] + html_path = str(filename).replace(".json", ".html") + perfetto_to_html(filename, html_path) + wandb.log({"profile_trace": wandb.Html(html_path)}) + + +class MemSnapshotsProfilerWandb(MemSnapshotsProfiler): + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + if get_is_master() and wandb.run is not None: + filename = list( + Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html") + )[0] + wandb.log({"memory_trace": wandb.Html(open(filename), inject=False)}) + + +@contextlib.contextmanager +def maybe_run_profiler(dump_dir, module, config: ProfilerArgs): + # get user defined profiler settings + + if config.run: + trace_dir = os.path.join(dump_dir, config.trace_folder) + + logger.info(f"Profiling active. 
Traces will be saved at {trace_dir}") + + if get_is_master() and not os.path.exists(trace_dir): + os.makedirs(trace_dir) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + with xformers.profiler.profile( + output_dir=trace_dir, + module=module, + schedule=[ + ( + MemSnapshotsProfilerWandb, + config.mem_warmup, + config.mem_warmup + config.mem_steps, + ), + ( + PyTorchProfilerWandb, + config.profile_warmup, + config.profile_warmup + config.profile_steps, + ), + ], + ) as profiler: + yield profiler + + else: + yield None diff --git a/bytelatent/stool.py b/bytelatent/stool.py new file mode 100644 index 0000000..965f4cb --- /dev/null +++ b/bytelatent/stool.py @@ -0,0 +1,237 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import json +import os +import shutil +import subprocess +from dataclasses import dataclass +from typing import Any, Dict + +from omegaconf import OmegaConf + + +@dataclass +class StoolArgs: + config: Any = None + launcher: str = "sbatch" # Can be sbatch or bash if already in salloc + script: str = "apps.main.train" # The script to run. + copy_code: bool = True # Whether to copy code to dump dir + dirs_exists_ok: bool = ( + False # Whether to copy new code and config and run even if the dir already exists + ) + override: bool = False # Whether to delete the dump dir and restart + nodes: int = -1 # The number of nodes to run the job on. + ngpu: int = 8 # The number of GPUs required per node. + ncpu: int = 16 # The number of CPUs allocated per GPU. + mem: str = "" # The amount of memory to allocate. + anaconda: str = "default" # The path to the anaconda environment. + constraint: str = "" # The constraint on the nodes. + exclude: str = "" # The nodes to exclude. + time: int = -1 # The time limit of the job (in minutes). 
+ account: str = "" + qos: str = "" + partition: str = "learn" + stdout: bool = False + + +SBATCH_COMMAND = """#!/bin/bash + +{exclude} +{qos} +{account} +{constraint} +#SBATCH --job-name={name} +#SBATCH --nodes={nodes} +#SBATCH --gres=gpu:{ngpus} +#SBATCH --cpus-per-gpu={ncpu} +#SBATCH --time={time} +#SBATCH --partition={partition} +#SBATCH --mem={mem} + +#SBATCH --output={dump_dir}/logs/%j/%j.stdout +#SBATCH --error={dump_dir}/logs/%j/%j.stderr + +#SBATCH --open-mode=append +#SBATCH --signal=USR2@120 +#SBATCH --distribution=block + +# Mimic the effect of "conda init", which doesn't work for scripts +eval "$({conda_exe} shell.bash hook)" +source activate {conda_env_path} + +{go_to_code_dir} + +export OMP_NUM_THREADS=1 +export LAUNCH_WITH="SBATCH" +export DUMP_DIR={dump_dir} +srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml +""" + + +def copy_dir(input_dir: str, output_dir: str) -> None: + print(f"Copying : {input_dir}\n" f"to : {output_dir} ...") + assert os.path.isdir(input_dir), f"{input_dir} is not a directory" + assert os.path.isdir(output_dir), f"{output_dir} is not a directory" + rsync_cmd = ( + f"rsync -arm --copy-links " + f"--include '**/' " + f"--include '*.py' " + f"--exclude='*' " + f"{input_dir}/ {output_dir}" + ) + print(f"Copying command: {rsync_cmd}") + subprocess.call([rsync_cmd], shell=True) + print("Copy done.") + + +def retrieve_max_time_per_partition() -> Dict[str, int]: + # retrieve partition max times (a bit slow) + + sinfo = json.loads(subprocess.check_output("sinfo --json", shell=True))["sinfo"] + max_times: Dict[str, int] = {} + + for info in sinfo: + if info["partition"]["maximums"]["time"]["infinite"]: + max_times[info["partition"]["name"]] = 14 * 24 * 60 # 14 days + else: + max_times[info["partition"]["name"]] = info["partition"]["maximums"][ + "time" + ][ + "number" + ] # in minutes + + return max_times + + +def validate_args(args) -> None: + # Set maximum time limit if not specified + if args.time == -1: + max_times = retrieve_max_time_per_partition() + args.time = max_times.get( + args.partition, 3 * 24 * 60 + ) # Default to 3 days if not found + print( + f"No time limit specified, using max time for partitions: {args.time} minutes" + ) + + if args.constraint: + args.constraint = f"#SBATCH --constraint={args.constraint}" + + if args.account: + args.account = f"#SBATCH --account={args.account}" + + if args.qos: + args.qos = f"#SBATCH --qos={args.qos}" + + if getattr(args, "exclude", ""): + args.exclude = f"#SBATCH --exclude={args.exclude}" + + if hasattr(args, "anaconda") and args.anaconda: + if args.anaconda == "default": + args.anaconda = ( + subprocess.check_output("which python", shell=True) + .decode("ascii") + .strip() + ) + else: + args.anaconda = f"{args.anaconda}/bin/python" + assert os.path.isfile(args.anaconda) + + args.mem = args.mem or "0" + + assert args.partition + assert args.ngpu > 0 + assert args.ncpu > 0 + assert args.nodes > 0 + assert args.time > 0 + assert args.partition + + +def launch_job(args: StoolArgs): + # Set up args default and validate them depending on the cluster or partition requested + validate_args(args) + dump_dir = args.config["dump_dir"] + job_name = args.config["name"] + print("Creating directories...") + os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override) + if args.override: + confirm = input( + f"Are you sure you want to delete the directory '{dump_dir}'? This action cannot be undone. 
(yes/no): " + ) + if confirm.lower() == "yes": + shutil.rmtree(dump_dir) + print(f"Directory '{dump_dir}' has been deleted.") + else: + print("Operation cancelled.") + return + if args.copy_code: + os.makedirs(f"{dump_dir}/code", exist_ok=args.dirs_exists_ok) + print("Copying code ...") + copy_dir(os.getcwd(), f"{dump_dir}/code") + + print("Saving config file ...") + with open(f"{dump_dir}/base_config.yaml", "w") as cfg: + cfg.write(OmegaConf.to_yaml(args.config)) + + conda_exe = os.environ.get("CONDA_EXE", "conda") + conda_env_path = os.path.dirname(os.path.dirname(args.anaconda)) + log_output = ( + "-o $DUMP_DIR/logs/%j/%j_%t.out -e $DUMP_DIR/logs/%j/%j_%t.err" + if not args.stdout + else "" + ) + sbatch = SBATCH_COMMAND.format( + name=job_name, + script=args.script, + dump_dir=dump_dir, + nodes=args.nodes, + tasks=args.nodes * args.ngpu, + nodes_per_run=args.nodes, + ngpus=args.ngpu, + ncpu=args.ncpu, + mem=args.mem, + qos=args.qos, + account=args.account, + constraint=args.constraint, + exclude=args.exclude, + time=args.time, + partition=args.partition, + conda_exe=conda_exe, + conda_env_path=conda_env_path, + log_output=log_output, + go_to_code_dir=f"cd {dump_dir}/code/" if args.copy_code else "", + ) + + print("Writing sbatch command ...") + with open(f"{dump_dir}/submit.slurm", "w") as f: + f.write(sbatch) + + print("Submitting job ...") + os.system(f"{args.launcher} {dump_dir}/submit.slurm") + + print("Done.") + + +if __name__ == "__main__": + """ + The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments + This accepts arguments as a dot list + So if the dataclass looks like + + @dataclass + class DummyArgs: + name: str + mode: LMTransformerArgs + + @dataclass + class LMTransformerArgs: + dim: int + + Then you can pass model.dim=32 to change values in LMTransformerArgs + or just name=tictac for top level attributes. + """ + raise NotImplementedError("Update this to blt code") + args = OmegaConf.from_cli() + args.config = OmegaConf.load(args.config) + args = dataclass_from_dict(StoolArgs, args) + launch_job(args) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py new file mode 100644 index 0000000..73ad9f7 --- /dev/null +++ b/bytelatent/test_blt.py @@ -0,0 +1,471 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
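`stool.py` above expects an OmegaConf-style config carrying at least `dump_dir` and `name`; `validate_args` fills in SLURM defaults, and `launch_job` writes `base_config.yaml` plus a `submit.slurm` script before submitting it. A programmatic sketch for a machine with SLURM available (all paths and resource values below are hypothetical):

from omegaconf import OmegaConf

from bytelatent.stool import StoolArgs, launch_job

# Minimal config: launch_job itself only reads dump_dir and name from it;
# the rest is passed through to the training script via base_config.yaml.
cfg = OmegaConf.create({"dump_dir": "/tmp/blt_debug_run", "name": "blt_debug"})

launch_job(
    StoolArgs(
        config=cfg,
        launcher="sbatch",  # or "bash" when already inside an salloc
        nodes=1,
        ngpu=8,
        partition="learn",
        time=60,  # minutes; -1 would query the partition maximum via sinfo
    )
)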
+import os +from dataclasses import replace + +import numpy as np +import pytest +import torch + +from bytelatent.constants import BLT_DATA +from bytelatent.data.data_types import Batch +from bytelatent.data.ngram_processor import NgramProcessor +from bytelatent.model.blt import ( + ByteLatentTransformer, + ByteLatentTransformerArgs, + EmbeddingType, + compute_hash_embeddings, + create_global_transformer, + create_local_decoder, + create_local_encoder, + cross_attn_mask, + decoder_patch_ids_from_lengths, + get_blt_input, + init_embeddings, + patch_ids_from_lengths, +) +from bytelatent.model.transformer import CrossAttention +from bytelatent.model.utils import create_causal_mask +from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.train import compute_loss + + +def batch_to_tensors_and_gpu(batch): + x = torch.from_numpy(batch.x) + y = torch.from_numpy(batch.y) + mask = None if batch.mask is None else torch.from_numpy(batch.mask) + patch_lengths = ( + None if batch.patch_lengths is None else torch.from_numpy(batch.patch_lengths) + ) + ngram_ids = None if batch.ngram_ids is None else torch.from_numpy(batch.ngram_ids) + + if torch.cuda.is_available(): + x = x.cuda() + y = y.cuda() + if mask is not None: + mask = mask.cuda() + if patch_lengths is not None: + patch_lengths = patch_lengths.cuda() + if ngram_ids is not None: + ngram_ids = ngram_ids.cuda() + return x, y, mask, patch_lengths, ngram_ids + + +def fake_batch(): + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + del batch_dict["x2"] + del batch_dict["y2"] + del batch_dict["src_names"] + return Batch(**batch_dict) + + +def create_args(cross_attention=False): + transformer_args = ByteLatentTransformerArgs( + # Base args provided + n_heads=8, + dim=512, + vocab_size=260, + # Additional args from command line + dim_token=256, + patch_size=6, + tokenization_mode="bytes", + patching_mode="space", + tie_local_encoder_decoder_logits=False, + data_loader_patching=True, + max_encoder_seq_length=12288, + pad_to_max_length=True, + encoder_lm_loss=False, + patching_threshold=3.1439168453216553, + encoder_hash_byte_group_size=[4], + encoder_hash_byte_group_vocab=50002, + encoder_hash_byte_group_nb_functions=3, + cross_attn_encoder=cross_attention, # True, + cross_attn_decoder=cross_attention, # True, + cross_attn_window_encoder=512, + cross_attn_window_decoder=512, + dim_local_encoder=256, + dim_local_decoder=256, + cross_attn_k=8, + cross_attn_nheads=4, + cross_attn_all_layers_decoder=True, + cross_attn_all_layers_encoder=True, + cross_attn_use_flex_attention=True, + cross_attn_init_by_pooling=True, + log_patch_lengths=True, + non_linearity="swiglu", + use_rope=True, + recompute_fc1_out=False, + recompute_fc3_out=False, + recompute_attn=False, + custom_bwd=False, + layer_ckpt="none", + efficient_attn="sdpa", + patch_only_encoder=False, + patch_only_decoder=False, + use_local_encoder_transformer=True, + init_use_gaussian=True, + init_use_depth="current", + attn_bias_type="block_causal", + alpha_depth="disabled", + max_length=256, + local_attention_window_len=512, + max_seqlen=12288, + downsampling_by_pooling="max", + ) + return transformer_args + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +class TestByteLatentTransformer: + def test_local_encoder(self): + args = create_args() + device = torch.device("cuda") + local_encoder = create_local_encoder(args).to(device) + + batch = fake_batch() + tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch) + + 
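        # Per the calls below: get_blt_input derives the byte-level encoder input
        # stream, patch_ids_from_lengths maps each byte position to its patch index,
        # and compute_hash_embeddings builds the hashed n-gram embeddings that are
        # fed to the local encoder together with the raw byte tokens.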
local_encoder_tokens, _, _ = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=0, + patch_size=local_encoder.patch_size, + boe_id=local_encoder.boe_id, + ) + + patch_ids = patch_ids_from_lengths( + patch_lengths, local_encoder_tokens.shape[-1] + ) + + encoder_hash_tok_embedding = init_embeddings( + args, + EmbeddingType.HASH_TOK, + local_encoder_dim=local_encoder.dim, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + ).to(device) + + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=local_encoder, + encoder_hash_tok_embedding=encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab, + ) + + reference_path = os.path.join(BLT_DATA, "local_encoder_tokens.pt") + reference_tokens = torch.load(reference_path).to(device) + torch.testing.assert_close( + local_encoder_tokens, + reference_tokens, + msg="Generated tokens don't match reference tokens", + ) + + (h_encoder, h_cross), cache_encoder = local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=None, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + assert h_encoder is not None + assert h_cross is None + assert cache_encoder is None + + expected_shape = ( + local_encoder_tokens.shape[0], + local_encoder_tokens.shape[1], + local_encoder.dim, + ) + assert h_encoder.shape == expected_shape + + def test_local_encoder_cross_attention(self): + args = create_args(cross_attention=True) + device = torch.device("cuda") + local_encoder = create_local_encoder(args).to(device) + + batch = fake_batch() + tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch) + + local_encoder_tokens, _, _ = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=0, + patch_size=local_encoder.patch_size, + boe_id=local_encoder.boe_id, + ) + + patch_ids = patch_ids_from_lengths( + patch_lengths, local_encoder_tokens.shape[-1] + ) + + encoder_hash_tok_embedding = init_embeddings( + args, + EmbeddingType.HASH_TOK, + local_encoder_dim=local_encoder.dim, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + ).to(device) + + cross_attn_mask_enc = cross_attn_mask( + patch_ids, + patch_lengths, + local_encoder_tokens.shape[-1], + patches_as_queries=True, + cross_attn_k=args.cross_attn_k, + window=args.cross_attn_window_encoder, + block_mask=True, + ) + + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=local_encoder, + encoder_hash_tok_embedding=encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab, + ) + (h_encoder, h_cross), cache_encoder = local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + assert h_encoder is not None + assert h_cross is not None + assert cache_encoder is None + expected_shape = ( + local_encoder_tokens.shape[0], + local_encoder_tokens.shape[1], + local_encoder.dim, + ) + assert h_encoder.shape == expected_shape + assert h_cross.shape == (2, 2048, local_encoder.dim) + + 
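    # Note: patch_ids_from_lengths appears to expand patch lengths into per-byte patch
    # indices, e.g. patch_lengths=[[3, 2, 4]] -> patch_ids=[[0, 0, 0, 1, 1, 2, 2, 2, 2]]
    # (toy illustration, not the library implementation); cross_attn_mask then uses
    # these ids to restrict which bytes and patches may attend to each other.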
def test_local_decoder_cross_attention(self): + args = create_args(cross_attention=True) + device = torch.device("cuda") + local_decoder = create_local_decoder(args).to(device) + + test_files = { + "dec_embeds": "dec_embeds.pt", + "decoder_tokens": "local_decoder_tokens.pt", + "patch_embeds": "decoder_patch_cross_embeds.pt", + } + batch = fake_batch() + _, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch) + + tensors = { + name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device) + for name, filename in test_files.items() + } + decoder_patch_ids = decoder_patch_ids_from_lengths( + patch_lengths, 0, tensors["decoder_tokens"].shape[-1] + ) + cross_attn_mask_dec = cross_attn_mask( + decoder_patch_ids, + patch_lengths, + tensors["decoder_tokens"].shape[-1], + patches_as_queries=False, + cross_attn_k=args.cross_attn_k, + window=args.cross_attn_window_decoder, + block_mask=True, + ) + output, _ = local_decoder( + embeds=tensors["dec_embeds"], + patch_embeds=tensors["patch_embeds"], + tokens=tensors["decoder_tokens"], + cross_mask=cross_attn_mask_dec, + cache=None, + ) + assert output is not None + assert output.shape == (2, tensors["decoder_tokens"].shape[1], args.vocab_size) + + def test_local_decoder(self): + args = create_args() + device = torch.device("cuda") + + local_decoder = create_local_decoder(args).to(device) + + test_files = { + "dec_embeds": "dec_embeds.pt", + "decoder_tokens": "local_decoder_tokens.pt", + "patch_embeds": "decoder_patch_embeds.pt", + } + + tensors = { + name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device) + for name, filename in test_files.items() + } + + output, cache_decoder = local_decoder( + embeds=tensors["dec_embeds"], + patch_embeds=tensors["patch_embeds"], + tokens=tensors["decoder_tokens"], + cross_mask=None, + cache=None, + ) + assert output is not None + expected_shape = ( + tensors["decoder_tokens"].shape[0], + tensors["decoder_tokens"].shape[1], + args.vocab_size, + ) + assert output.shape == expected_shape + assert cache_decoder is None + + def test_global_transformer(self): + args = create_args() + device = torch.device("cuda") + global_transformer = create_global_transformer(args).to(device) + + test_files = { + "global_embeds": "global_embeds.pt", + "global_tokens": "global_tokens.pt", + } + tensors = { + name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device) + for name, filename in test_files.items() + } + h, cache = global_transformer( + embeds=tensors["global_embeds"], tokens=tensors["global_tokens"] + ) + h is not None + assert h.shape == (2, 256, 512) + assert cache is None + + def test_blt_transformer_init(self): + args = create_args() + model = ByteLatentTransformer(args) + assert model is not None + + @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) + def test_blt_transformer_forward(self, attn_type): + args = create_args() + args = args.model_copy(update=dict(efficient_attn=attn_type)) + model = ByteLatentTransformer(args) + model = model.cuda() + batch = fake_batch() + x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch) + + output = model( + tokens=x, + patch_lengths=patch_lengths, + ngram_ids=ngram_ids, + ) + assert output is not None + expected_shape = ( + x.shape[0], + x.shape[1], + args.vocab_size, + ) + assert output.shape == expected_shape + + def test_blt_transformer_cross_attn_forward(self): + args = create_args(cross_attention=True) + model = ByteLatentTransformer(args) + model = model.cuda() + batch = fake_batch() + x, y, mask, patch_lengths, ngram_ids = 
batch_to_tensors_and_gpu(batch) + + output = model( + tokens=x, + patch_lengths=patch_lengths, + ngram_ids=ngram_ids, + ) + assert output is not None + expected_shape = ( + x.shape[0], + x.shape[1], + args.vocab_size, + ) + assert output.shape == expected_shape + + def test_cross_attention_rand(self): + x = torch.randn(2, 256, 512, device="cuda") + kv = torch.randn(2, 256, 512, device="cuda") + cross_attention = CrossAttention( + dim=512, + head_dim=64, + n_heads=8, + n_kv_heads=4, + norm_eps=1e-6, + ).to("cuda") + mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None) + output = cross_attention(x, kv, mask) + assert output is not None + assert output.shape == (2, 256, 512) + + def test_ngram_embeddings(self): + ngram_to_size = { + 2: 38396, + 3: 50000, + 4: 50000, + 5: 50000, + 6: 50000, + 7: 50000, + 8: 50000, + } + batch = fake_batch() + ngram_processor = NgramProcessor(BLT_DATA, ngram_to_size) + ngram_ids = ngram_processor.encode_token_ngrams(batch.x) + ngram_ids = np.stack(ngram_ids, axis=0) + batch = replace(batch, ngram_ids=ngram_ids) + args = create_args(cross_attention=True) + args = args.model_copy( + update=dict( + encoder_ngram_to_size_str="2:38396,3:50000,4:50000,5:50000,6:50000,7:50000,8:50000", + encoder_enable_byte_ngrams=True, + ngram_vocab_sizes=ngram_processor.ngram_vocab_sizes, + ) + ) + model = ByteLatentTransformer(args) + model = model.cuda() + x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch) + + output = model( + tokens=x, + patch_lengths=patch_lengths, + ngram_ids=ngram_ids, + ) + assert output is not None + expected_shape = ( + x.shape[0], + x.shape[1], + args.vocab_size, + ) + assert output.shape == expected_shape + + def test_loss_backward(self): + args = create_args() + args = args.model_copy(update=dict(efficient_attn="sdpa")) + batch = fake_batch() + model = ByteLatentTransformer(args) + steps = 10 + optimizer, scheduler = build_optimizer(model, OptimArgs(lr=4e-04), steps) + model = model.cuda() + x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch) + + initial_loss = None + final_loss = None + for step in range(steps): + output = model( + tokens=x, + patch_lengths=patch_lengths, + ngram_ids=ngram_ids, + ) + loss, _ = compute_loss(output, y, mask, 1.0) + if step == 0: + initial_loss = loss.item() + if step == steps - 1: + final_loss = loss.item() + prev_loss = loss.item() + loss.backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + assert ( + final_loss < initial_loss + ), f"Training did not reduce loss: initial {initial_loss}, final {final_loss}" diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py new file mode 100644 index 0000000..3acc42d --- /dev/null +++ b/bytelatent/test_entropy_model.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
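The test below loads the pretrained entropy model and checks its per-token entropies against entropies stored with the test examples. For intuition, the entropy of a next-byte distribution is presumably computed along these lines (a sketch, not the actual `bytelatent.data.patcher.entropy` implementation):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 8, 260)  # (batch, seq_len, byte vocab), as elsewhere in the tests
log_probs = F.log_softmax(logits, dim=-1)
token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # (batch, seq_len), in nats
print(token_entropy.shape)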
+import os + +import torch + +from bytelatent.constants import BLT_DATA +from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState +from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator +from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum, entropy +from bytelatent.entropy_model import load_entropy_model +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + +ENTROPY_MODEL = "transformer_100m" +ARROW_TEST_DATA = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow") + + +def test_entropy_model(): + initial_state = ArrowFileIteratorState( + file_path=None, + num_workers=1, + worker_id=0, + preprocess_dir=None, + entropy_model_name=ENTROPY_MODEL, + dataset_files=[ARROW_TEST_DATA], + row_num=0, + arrow_batch_size=100, + ) + arrow_file = initial_state.build() + tokenizer_args = TokenizerArgs( + name="blt", + init_kwargs={ + "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" + }, + ) + entropy_model = load_entropy_model( + BLT_DATA / "checkpoint_0100000_consolidated", + os.path.join( + BLT_DATA, + "entropy_model.pth", + ), + ) + preprocess_iter = PreprocessIterator( + arrow_file, + tokenizer_args=tokenizer_args, + patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.entropy), + add_patches=False, + ) + for example in preprocess_iter.create_iter(): + tokens = torch.tensor(example.tokens).unsqueeze(0) + expected_entropies = torch.tensor(example.entropies).unsqueeze(0) + preds = entropy_model(tokens) + pred_entropies = entropy(preds) + assert pred_entropies.shape == expected_entropies.shape + assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5) + break diff --git a/bytelatent/tokenizers/__init__.py b/bytelatent/tokenizers/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/tokenizers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/tokenizers/abstract_tokenizer.py b/bytelatent/tokenizers/abstract_tokenizer.py new file mode 100644 index 0000000..fff26c3 --- /dev/null +++ b/bytelatent/tokenizers/abstract_tokenizer.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import abc + + +class Tokenizer(abc.ABC): + @abc.abstractmethod + def encode(self, text: str, add_bos: bool, add_eos: bool): + pass + + @abc.abstractmethod + def decode(self, tokens: list[int]): + pass + + @abc.abstractmethod + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str], list[int]]: + """Return the offsets of the tokens in the original text. Only used for evaluation.""" + pass diff --git a/bytelatent/tokenizers/blt_tokenizer.py b/bytelatent/tokenizers/blt_tokenizer.py new file mode 100644 index 0000000..a3462e1 --- /dev/null +++ b/bytelatent/tokenizers/blt_tokenizer.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
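The abstract `Tokenizer` interface above only requires `encode`, `decode`, and `get_token_offsets`. A minimal sketch of a custom implementation (a toy per-codepoint tokenizer, purely illustrative):

from bytelatent.tokenizers.abstract_tokenizer import Tokenizer


class CodepointTokenizer(Tokenizer):
    """Toy tokenizer: one token per Unicode codepoint."""

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False):
        # BOS/EOS handling is omitted to keep the sketch short.
        return [ord(c) for c in text]

    def decode(self, tokens: list[int]):
        return "".join(chr(t) for t in tokens)

    def get_token_offsets(self, text: str, tokens: list[int] | None = None):
        tokens = tokens if tokens is not None else self.encode(text)
        return [chr(t) for t in tokens], list(range(len(tokens)))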
+import re + +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer +from bytelatent.tokenizers.constants import ( + BOE_ID, + BOS_ID, + BPE_ID, + BYTE_UNITS, + EOS_ID, + OFFSET, + PAD_ID, +) +from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer + + +def convert_to_bytes(s): + # check if the output is a bytes like object of the format <0x00> + if re.match(r"<0x[0-9a-fA-F]+>", s): + return bytes.fromhex(s[3:-1]) + else: + return bytes(s, "utf-8", errors="ignore") + + +def text2bytes_bpe_delims( + text: str, + *, + bpe_tokenizer, + bpe_id: int, + offsetting_special_char: int, + add_bos: bool, + add_eos: bool, +): + cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos) + # merge the leading space tokens + leading_space_tokens = [] + other_bpe_tokens = [] + leading = True + for token in cur_bpe: + bpe_str = bpe_tokenizer.sp_model.id_to_piece(token) + if leading and all(c == "▁" for c in bpe_str): + leading_space_tokens.append(bpe_str) + else: + leading = False + other_bpe_tokens.append(bpe_str) + cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens + + # Remove the '▁' characters + bpe_strs = [] + for i, bpe_str in enumerate(cur_bpe_strs): + if ( + len(bpe_strs) <= 1 + and all([c == " " for s in bpe_strs for c in s]) + and not all(c == "▁" for c in bpe_str) + ): + # Remove leading space for first non space token. + bpe_str = bpe_str.replace("▁", "") + elif i == 0 and all(c == "▁" for c in bpe_str): + bpe_str = " " * (len(text) - len(text.lstrip(" "))) + else: + bpe_str = bpe_str.replace("▁", " ") + if len(bpe_str) > 0: + bpe_strs.append(bpe_str) + ex_seq = [] + # Convert bpe tokens to bytes + for s in bpe_strs: + byte_chunk = convert_to_bytes(s) + proc_chunk = [int(unit) for unit in byte_chunk] + ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk) + + return ex_seq + + +class BltTokenizer(Tokenizer): + def __init__( + self, + *, + vocab_size_unit_1: int = BYTE_UNITS, + bpe_delim: bool = False, + bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", + add_bos: bool = True, + add_eos: bool = True, + ): + self.add_bos = add_bos + self.add_eos = add_eos + self.vocab_size_unit_1 = vocab_size_unit_1 + self.boe_id = BOE_ID + self.bos_id = BOS_ID + self.eos_id = EOS_ID + self.pad_id = PAD_ID + self.bpe_id = BPE_ID + self.bpe_tokenizer_path = bpe_tokenizer_path + if bpe_delim: + self.bpe_tokenizer = SentencePieceTokenizer( + model_path=self.bpe_tokenizer_path + ) + else: + self.bpe_tokenizer = None + self.bpe_delim = bpe_delim + self.offsetting_special_char = OFFSET + self.vocab_size_unit_1 = vocab_size_unit_1 + self.n_words = vocab_size_unit_1 + self.offsetting_special_char + + def encode( + self, text: str, add_bos: bool | None = None, add_eos: bool | None = None + ): + if add_bos is None: + add_bos = self.add_bos + if add_eos is None: + add_eos = self.add_eos + + if self.bpe_delim: + tokens = text2bytes_bpe_delims( + text, + bpe_tokenizer=self.bpe_tokenizer, + bpe_id=self.bpe_id, + offsetting_special_char=self.offsetting_special_char, + add_bos=False, + add_eos=False, + ) + else: + tokens = bytes(text, encoding="utf-8", errors="ignore") + + # Offsetting + tokens = [int(unit) + self.offsetting_special_char for unit in tokens] + + if add_bos: + tokens.insert(0, self.bos_id) + if add_eos: + tokens.append(self.eos_id) + + return tokens + + def decode(self, tokens: list[int], cut_at_eos: bool = False): + if cut_at_eos: + for k, t in enumerate(tokens): + if t == self.eos_id: + tokens = tokens[: k + 1] + break 
+ return bytes( + [ + tok - self.offsetting_special_char + for tok in tokens + if tok - self.offsetting_special_char >= 0 + ] + ).decode("utf-8", errors="ignore") + + def get_token_offsets(self, text: str, tokens: list[int] | None = None): + # TODO: Figure out what this does + raise NotImplementedError() diff --git a/bytelatent/tokenizers/build_tokenizer.py b/bytelatent/tokenizers/build_tokenizer.py new file mode 100644 index 0000000..8aa434d --- /dev/null +++ b/bytelatent/tokenizers/build_tokenizer.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +from typing import Any + +from pydantic import BaseModel + +from bytelatent.tokenizers.blt_tokenizer import BltTokenizer +from bytelatent.tokenizers.byte_tokenizer import ByteTokenizer +from bytelatent.tokenizers.tiktoken_tokenizer import TikTokenTokenizer + +try: + from sentencepiece import SentencePieceProcessor + + has_sp = True +except ImportError: + has_sp = False + +try: + import tiktoken + from tiktoken.load import load_tiktoken_bpe + + has_tiktoken = True +except ImportError: + has_tiktoken = False + +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer +from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer + +logger = logging.getLogger(__name__) + + +class MockTokenizer(Tokenizer): + n_words: int = 256 + + def encode(self, text: str, add_bos: bool, add_eos: bool): + return text + + def decode(self, tokens): + raise NotImplementedError() + + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str]]: + raise NotImplementedError() + + +class TokenizerArgs(BaseModel): + name: str = "bytes" + init_kwargs: dict[str, Any] | None = None + + def build(self) -> Tokenizer: + if self.init_kwargs is None: + init_kwargs = {} + else: + init_kwargs = self.init_kwargs + if self.name == "blt": + return BltTokenizer(**init_kwargs) + elif self.name == "bytes": + return ByteTokenizer(**init_kwargs) + elif self.name == "mock": + return MockTokenizer(**init_kwargs) + elif self.name == "sp": + assert has_sp, "sentencepiece not installed" + return SentencePieceTokenizer(**init_kwargs) + elif self.name == "tiktoken": + assert has_tiktoken, "tiktoken not installed" + return TikTokenTokenizer(**init_kwargs) + else: + raise NotImplementedError(f"{self.name} tokenizer type is not implemented") diff --git a/bytelatent/tokenizers/byte_tokenizer.py b/bytelatent/tokenizers/byte_tokenizer.py new file mode 100644 index 0000000..f85f4f7 --- /dev/null +++ b/bytelatent/tokenizers/byte_tokenizer.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
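`TokenizerArgs` above acts as a small factory over the concrete tokenizers. The byte-level variant defined just below needs no external model file, so a quick round trip looks roughly like this (a sketch):

from bytelatent.tokenizers.build_tokenizer import TokenizerArgs

tok = TokenizerArgs(name="bytes").build()  # returns the ByteTokenizer defined below
ids = tok.encode("hello", add_bos=True, add_eos=True)
print(ids)              # [256, 104, 101, 108, 108, 111, 257]
print(tok.decode(ids))  # "hello" -- ids >= 256 (BOS/EOS) are dropped on decode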
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer + + +class ByteTokenizer(Tokenizer): + def __init__(self): + self.bos_id = 256 + self.eos_id = 257 + self.n_words = 258 + + def encode(self, s: str, add_bos: bool = False, add_eos: bool = False): + tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos + return tokens + + def decode(self, tokens: list[int]): + byte_tokens = bytes([t for t in tokens if t < 256]) + return byte_tokens.decode("utf-8", errors="backslashreplace") + + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str], list[int]]: + if tokens is None: + tokens = self.encode(text) + + decoded_chars, offsets = [], [] + byte_pos = 0 + for token in tokens: + if token < 256: + char = bytes([token]).decode("utf-8", errors="ignore") + if char: + decoded_chars.append(char) + offsets.append(byte_pos) + byte_pos += len(char.encode("utf-8")) + + return decoded_chars, offsets diff --git a/bytelatent/tokenizers/constants.py b/bytelatent/tokenizers/constants.py new file mode 100644 index 0000000..774119e --- /dev/null +++ b/bytelatent/tokenizers/constants.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + + +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 + +BYTE_UNITS: int = 256 diff --git a/bytelatent/tokenizers/sentence_piece_tokenizer.py b/bytelatent/tokenizers/sentence_piece_tokenizer.py new file mode 100644 index 0000000..faeb997 --- /dev/null +++ b/bytelatent/tokenizers/sentence_piece_tokenizer.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +import os + +try: + from sentencepiece import SentencePieceProcessor + + has_sp = True +except ImportError: + has_sp = False + +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer + +logger = logging.getLogger(__name__) + + +class SentencePieceTokenizer(Tokenizer): + def __init__( + self, model_path: str, add_bos: bool = True, add_eos: bool = True + ) -> None: + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + logger.info(f"Reloaded SentencePiece model from {model_path}") + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.pad_id() + self.add_bos = add_bos + self.add_eos = add_eos + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None): + if add_bos is None: + add_bos = self.add_bos + + if add_eos is None: + add_eos = self.add_eos + assert type(s) is str + tokens = ( + [self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos + ) + return tokens + + def decode(self, tokens: list[int]): + return self.sp_model.decode(tokens) + + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str], list[int]]: + pieces = self.sp_model.encode_as_immutable_proto(text).pieces + substrs = [p.surface for p in pieces] + offsets = [p.begin for p in pieces] + return substrs, offsets diff --git a/bytelatent/tokenizers/test_blt_tokenizer.py b/bytelatent/tokenizers/test_blt_tokenizer.py new file mode 100644 index 0000000..50308da --- /dev/null +++ b/bytelatent/tokenizers/test_blt_tokenizer.py @@ 
-0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import json + +from bytelatent.constants import BLT_DATA +from bytelatent.tokenizers.blt_tokenizer import BltTokenizer +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +def test_tokenizer_bytes(): + with open("fixtures/tokenizer_data.json") as f: + data = json.load(f) + + examples: list[str] = data["texts"] + examples_tokens: list[list[int]] = data["tokens"] + + tokenizer = BltTokenizer(bpe_delim=False) + for i in range(len(examples)): + assert tokenizer.encode(examples[i]) == examples_tokens[i] + + +def test_tokenizer_bpe(): + with open("fixtures/tokenizer_data_bpe_delim.json") as f: + data = json.load(f) + + examples: list[str] = data["texts"] + examples_tokens: list[list[int]] = data["tokens"] + + tokenizer = BltTokenizer(bpe_delim=True) + for i in range(len(examples)): + assert tokenizer.encode(examples[i]) == examples_tokens[i] + + +def test_build_tokenizer_from_args(): + tokenizer_args = TokenizerArgs( + name="blt", + init_kwargs={ + "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" + }, + ) + tokenizer = tokenizer_args.build() + assert tokenizer.encode("test text") is not None diff --git a/bytelatent/tokenizers/tiktoken_tokenizer.py b/bytelatent/tokenizers/tiktoken_tokenizer.py new file mode 100644 index 0000000..f498bf2 --- /dev/null +++ b/bytelatent/tokenizers/tiktoken_tokenizer.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +from copy import copy +from pathlib import Path + +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer + +try: + import tiktoken + from tiktoken.load import load_tiktoken_bpe + + has_tiktoken = True +except ImportError: + has_tiktoken = False +DEFAULT_TIKTOKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" +DEFAULT_TIKTOKEN_SPECIAL_TOKENS = { + "<|begin_of_text|>": 0, + "<|end_of_text|>": 1, + "<|fim_prefix|>": 2, + "<|fim_middle|>": 3, + "<|fim_end_fill|>": 253, + "<|fim_pad|>": 254, + "<|fim_suffix|>": 255, +} +TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +logger = logging.getLogger(__name__) + + +class TikTokenTokenizer(Tokenizer): + def __init__(self, model_path: str) -> None: + mergeable_ranks = load_tiktoken_bpe(model_path) + all_special_tokens_with_ids = copy(DEFAULT_TIKTOKEN_SPECIAL_TOKENS) + missing_ids = set(range(256)) - set(all_special_tokens_with_ids.values()) + for id in missing_ids: + all_special_tokens_with_ids[f"<|reserved_special_token_{id}|>"] = id + for name in all_special_tokens_with_ids: + all_special_tokens_with_ids[name] += len(mergeable_ranks) + + self.tkt_model = tiktoken.core.Encoding( + name=Path(model_path).stem, + pat_str=DEFAULT_TIKTOKEN_PATTERN, + mergeable_ranks=mergeable_ranks, + special_tokens=all_special_tokens_with_ids, + ) + + self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>") + self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>") + + self.n_words: int = self.tkt_model.n_vocab + + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + + def encode(self, s: str, add_bos: bool, add_eos: bool): + assert isinstance(s, str) + + subs = [] + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS): + subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS]) + return ( + [self.bos_id] * add_bos + + sum(self.tkt_model.encode_ordinary_batch(subs), start=[]) + + [self.eos_id] * add_eos + ) + + def decode(self, tokens: 
list[int]): + return self.tkt_model.decode(tokens) + + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str], list[int]]: + if tokens is not None: + token_bytes = self.tkt_model.decode_tokens_bytes(tokens) + else: + token_bytes = self.tkt_model.decode_tokens_bytes( + self.tkt_model.encode(text, allowed_special="all") + ) + + text_len, offsets = 0, [] + for token in token_bytes: + offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0))) + text_len += sum(1 for c in token if not 0x80 <= c < 0xC0) + substrs = [text[s:e] for s, e in zip(offsets, offsets[1:] + [None])] + return substrs, offsets diff --git a/bytelatent/train.py b/bytelatent/train.py new file mode 100644 index 0000000..6cb13b9 --- /dev/null +++ b/bytelatent/train.py @@ -0,0 +1,651 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import gc +import logging +import os +import sys +from contextlib import ExitStack +from copy import deepcopy +from dataclasses import asdict, dataclass +from pathlib import Path +from timeit import default_timer as timer +from typing import Any, Dict, Type, TypeVar + +import torch +import torch.distributed +import torch.nn.functional +import torch.nn.functional as F +import wandb +import xformers.profiler +from omegaconf import OmegaConf +from torch.distributed._tensor import DTensor +from torch.distributed.checkpoint.stateful import Stateful +from torch.optim import lr_scheduler + +from bytelatent.args import TrainArgs +from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.data_types import DataLoaderState +from bytelatent.distributed import ( + check_model_value_range, + clean_env, + dist_mean_dict, + get_device_mesh, + get_is_master, + get_world_size, + init_signal_handler, + parallelize_model, + requeue_slurm_job, + setup_env, + setup_torch_distributed, +) +from bytelatent.logger import init_logger +from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params +from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.optim import build_optimizer +from bytelatent.probe import AutoProbeD +from bytelatent.profiling import maybe_run_profiler +from bytelatent.stool import StoolArgs, launch_job +from bytelatent.transformer import ( + build_fsdp_grouping_plan, + get_no_recompute_ops, + get_num_flop_per_token, + tp_parallelize, +) + +logger = logging.getLogger() + +T = TypeVar("T") + + +def flatten_dict(d, parent_key="", sep="_"): + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T: + """ + Converts a dictionary to a dataclass instance, recursively for nested structures. 
+ """ + base = OmegaConf.structured(cls()) + OmegaConf.set_struct(base, strict) + override = OmegaConf.create(data) + return OmegaConf.to_object(OmegaConf.merge(base, override)) + + +@dataclass +class TrainState(Stateful): + step: int # Nb of steps taken by the optimizer + acc_step: int # Nb of accumulation steps done since last optimizer step + scheduler: lr_scheduler.LambdaLR + data_loader_state: DataLoaderState + scale: float = 1.0 + + def state_dict(self) -> Dict[str, Any]: + return { + "step": self.step, + "acc_step": self.acc_step, + "data_loader_state": self.data_loader_state.dict(), + "scheduler": self.scheduler.state_dict(), + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.acc_step = state_dict["acc_step"] + self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"]) + self.scheduler.load_state_dict(state_dict["scheduler"]) + + +def validate_train_args(args: TrainArgs, output_size: int): + if args.model.vocab_size < 0: + logger.info(f"Setting model output size to {args.model.vocab_size}") + args.model.vocab_size = output_size + + assert args.dump_dir, "Dump dir not set" + + if args.checkpoint.path is None: + logger.info(f"Setting checkpoint path to {args.checkpoint.path}") + args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints") + + for source in args.data.sources: + data_path = os.path.join(args.data.root_dir, source) + assert os.path.exists(data_path), f"{data_path} doesn't exist" + + if ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + != get_world_size() + ): + assert get_world_size() % args.distributed.dp_shard == 0 + args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard + + assert args.distributed.dp_replicate % args.distributed.tp_size == 0 + args.distributed.dp_replicate = ( + args.distributed.dp_replicate // args.distributed.tp_size + ) + + logger.warning( + f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}" + ) + assert ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + == get_world_size() + ) + + if args.distributed.fsdp_type == "no_shard": + assert ( + args.distributed.dp_shard == 1 + and args.distributed.dp_replicate == get_world_size() + ) + + args.model.max_seqlen = args.data.seq_len + + if args.distributed.tp_size == 1: + logger.warning( + "Tensor parallelism has not been tested for a while, use at your own risk" + ) + + assert ( + args.probe_freq != args.profiling.mem_steps + ), "Don't profile during probe step" + assert ( + args.probe_freq != args.profiling.profile_steps + ), "Don't profile during probe step" + if args.logging.wandb is not None: + args.logging.wandb.name = args.name + + if args.probe_freq is not None: + assert ( + args.distributed.tp_size == 1 + ), "Probing not supported with tensor parallelism" + assert ( + args.distributed.selective_activation_checkpointing is False + ), "Probing not supported with selective activation checkpointing" + + +preemption_flag = dict(flag=False) + + +def set_preemption_flag(signum, frame): + logger.warning("Signal handler called with signal " + str(signum)) + logger.warning("Preemption ! 
checkpointing asap and exiting.") + preemption_flag["flag"] = True + + +def every_n_steps(train_state, freq, acc_step=None, acc_freq=None): + test = train_state.step % freq == 0 + if acc_step is not None: + test = test and (train_state.acc_step == acc_step) + elif acc_freq is not None: + test = test and ((train_state.acc_step % acc_freq) == 0) + return test + + +def compute_loss(p, y, mask, scale): + tok_loss = scale * F.cross_entropy( + p.flatten(0, 1), y.flatten(0, 1), reduction="none" + ) + if mask is None: + loss = tok_loss.mean() + else: + mask = mask.flatten(0, 1) + tok_loss = tok_loss * mask + loss = tok_loss.sum() / (mask.sum() + 1e-6) + return loss, tok_loss + + +def train(args: TrainArgs): + with ExitStack() as context_stack: + tokenizer = args.data.tokenizer_args.build() + validate_train_args( + args, + tokenizer.n_words, + ) + if get_is_master(): + os.makedirs(args.dump_dir, exist_ok=True) + args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml") + init_logger(Path(args.dump_dir) / "train.log") + init_signal_handler(set_preemption_flag) # For handling preemption signals. + setup_env(args.env) + setup_torch_distributed(args.distributed) + world_mesh = get_device_mesh(args.distributed) + logger.info(f"Starting job: {args.name}") + + # build dataloader + # need dp world size and rank + dp_mesh = world_mesh["dp_replicate"] + dp_degree = dp_mesh.size() + dp_rank = dp_mesh.get_local_rank() + if args.distributed.dp_shard > 1: + dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank() + dp_degree *= world_mesh["dp_shard"].size() + + logger.info(f"Running on dp rank : {dp_rank}") + logger.info(f"Running on dp size : {dp_degree}") + + torch.manual_seed(args.seed) + logger.info("Building model") + + # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory + with torch.device("meta"): + model = ByteLatentTransformer(args.model) + logger.info("Model is built !") + + model_param_count = get_num_params(model) + + model = parallelize_model( + model, + world_mesh, + args.model, + args.distributed, + fsdp_grouping_plan=build_fsdp_grouping_plan(args.model), + tp_parallelize=tp_parallelize, + no_recompute_ops=get_no_recompute_ops(), + ) + + # Once we shard the model on different gpus we can actually initialize the model + # First we create empty tensors of the correct shapes + model = model.to_empty(device="cuda") + # Then we init the model. 
Please make sure this function initializes *ALL* parameters + # and buffers, otherwise you will have random values in the unitialized tensors + # which will silently fail (give nan gradients for example) + + if args.checkpoint.init_ckpt_path: + logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + load_from_checkpoint( + args.checkpoint.init_ckpt_path, model, model_key="model" + ) # Put model_key="" if its directly the model checkpoint + model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded + else: + with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + torch.manual_seed(args.model.seed) + model.init_weights() + check_model_value_range(model, range=10.0, std=1.0) + + # log model size + + logger.info(f"Model size: {model_param_count:,} total parameters") + + gpu_memory_monitor = GPUMemoryMonitor("cuda") + logger.info( + f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) " + f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory" + ) + logger.info(f"GPU memory usage: {gpu_memory_monitor}") + + # build optimizer after apply parallelisms to the model + optimizer, scheduler = build_optimizer(model, args.optim, args.steps) + data_loader = args.data.build_from_rank(dp_rank, dp_degree) + data_loader_state = data_loader.get_state() + + train_state = TrainState( + step=0, + acc_step=0, + data_loader_state=data_loader_state, + scheduler=scheduler, + scale=1.0, + ) + + checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint) + checkpoint.load(model, optimizer, train_state, world_mesh) + # Either load from latest checkpoint or start from scratch + if args.probe_freq is not None: + if get_is_master(): + os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True) + torch.distributed.barrier() + probe = AutoProbeD( + model, + ( + Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl" + if (dp_rank % 128 == 0) + else None + ), + ) + probe_mod = model._orig_mod if args.distributed.compile else model + + gc.disable() + + # train loop + model.train() + metric_logger = context_stack.enter_context( + MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args) + ) + data_loader = train_state.data_loader_state.build() + batch_iterator = data_loader.create_iter() + + torch_profiler = context_stack.enter_context( + maybe_run_profiler(args.dump_dir, model, args.profiling) + ) + + nwords_since_last_log = 0 + time_last_log = timer() + gc.collect() + while train_state.step < args.steps: + # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 + train_state.acc_step += 1 + train_state.acc_step = train_state.acc_step % args.grad_acc_steps + + # get batch + curr_lr = float(optimizer.param_groups[0]["lr"]) + data_load_start = timer() + batch = next(batch_iterator) + batch_x = torch.from_numpy( + batch.x, + ).cuda() + batch_y = torch.from_numpy(batch.y).cuda() + batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() + mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + + if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None: + raise ValueError( + "Cannot enable byte ngrams and have batch.ngram_ids be None" + ) + ngram_ids = ( + None + if batch.ngram_ids is None + else torch.from_numpy(batch.ngram_ids).cuda() + ) + + if every_n_steps(train_state, args.gc_collect_freq, acc_step=0): + logger.info("garbage collection") + # we do garbage collection manually otherwise different processes + # run the GC at 
different times so they slow down the whole pipeline + gc.collect() + + data_load_time = round(timer() - data_load_start, 4) + nwords_since_last_log += batch_x.numel() + + bsz, seqlen = batch_y.shape + + # forward + start_timer = torch.cuda.Event(enable_timing=True) + end_timer = torch.cuda.Event(enable_timing=True) + start_timer.record() + + # This is an automatic probe that will compute statistics + # of all linears' inputs, weights and outputs + # along with attention logits and entropy + # both in forward and backward pass + tok_loss = None + if (args.probe_freq is not None) and every_n_steps( + train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps + ): + # Here we do a fake forward and backward pass on a smaller + # batch size to avoid OOM + # This assumes the model has no stateful layers (batch norm..) + assert ( + next(probe_mod.parameters()).grad is None + ), "Can't probe model if grads are not reset" + + with probe: + probe.metadata = { + "it": train_state.step, + "global_step": train_state.step, + "loop": "lingua", + } + # Non compiled model uses roughly 2x memory in our exps + # So we divide bsz by 2 or seqlen by 2 + probe_bsz = max(1, bsz // 2) + probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2) + probe_loss = probe_mod( + batch_x[:probe_bsz, :probe_seq], + batch_y[:probe_bsz, :probe_seq], + ) + probe_loss.backward() + # We zero grads to cancel this fake step + optimizer.zero_grad() + + assert ( + next(probe_mod.parameters()).grad is None + ), "Probe model shouldn't have grads at this point" + + pred = model( + batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids + ) + + loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + + # We scale loss with grad_acc_steps so the gradient is the same + # regardless of grad_acc_steps + loss = loss / args.grad_acc_steps + + # backward on scaled loss to create scaled gradients + loss.backward() + # For logging we undo that scaling + loss = loss.detach() * args.grad_acc_steps + + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + + grad_norm = ( + grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm + ).item() + + # optimizer step + if train_state.acc_step == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + train_state.step += 1 + + # updates the scale for next iteration + # training iteration complete + end_timer.record() + + torch.cuda.synchronize() + + curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4) + + # if profiler is active + if torch_profiler: + xformers.profiler.step() + + # log metrics + if every_n_steps( + train_state, + args.logging.freq, + acc_step=None if args.logging.acc_freq else 0, + acc_freq=args.logging.acc_freq, + ): + time_delta = timer() - time_last_log + wps = nwords_since_last_log / (time_delta * args.distributed.tp_size) + + gpu_mem_stats = gpu_memory_monitor.get_peak_stats() + + total_acc_steps = ( + args.grad_acc_steps * train_state.step + train_state.acc_step + ) + tokens_per_gpu = ( + total_acc_steps * args.data.batch_size * args.data.seq_len + ) + total_tokens = dp_degree * tokens_per_gpu + # This is an estimate and the correct values may change + # if you change the architecture + # Use xformer's analyze profile trace to get actual measurement + FLOPS = ( + get_num_flop_per_token( + model_param_count - args.model.vocab_size * args.model.dim, + args.model.n_layers, + args.model.dim, + args.data.seq_len, + ) + * wps + ) + + metrics = flatten_dict( + { + 
"global_step": train_state.step, + "acc_step": train_state.acc_step, + "speed": { + "wps": wps, + "FLOPS": FLOPS, + "curr_iter_time": curr_iter_time, + "data_load_time": data_load_time, + }, + "optim": { + "grad_norm": grad_norm, + "lr": curr_lr, + "total_tokens": total_tokens, + }, + "memory": gpu_mem_stats._asdict(), + }, + sep="/", + ) + + to_sync = {} + to_sync["loss/out"] = loss.item() + metrics.update(dist_mean_dict(to_sync)) + + if get_is_master(): + metric_logger.log(metrics) + + gpu_memory_monitor.reset_peak_stats() + nwords_since_last_log = 0 + time_last_log = timer() + logger.info( + f"step: {train_state.step}" + f" acc: {train_state.acc_step}" + f" loss: {round(loss.item(),4):>7}" + f" grad: {grad_norm:.2e}" + f" flops: {FLOPS:.2e}" + f" wps: {wps:.2e}" + f" iter: {curr_iter_time:>7}" + f" data: {data_load_time:>5}" + f" lr: {curr_lr:.2e}" + f" mem: {gpu_mem_stats.max_active_pct:.0f}%" + f" pow: {gpu_mem_stats.power_draw/1000} W" + ) + + saved = False + if every_n_steps( + train_state, args.checkpoint.dump.every, acc_step=0 + ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + saved = checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + + if args.eval is not None and every_n_steps( + train_state, args.checkpoint.eval.every, acc_step=0 + ): + from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval + + eval_args = dataclass_from_dict(EvalArgs, args.eval) + + eval_args.global_step = train_state.step + eval_args.ckpt_dir = str(checkpoint.existing_saves[-1]) + eval_args.dump_dir = str( + os.path.join( + args.dump_dir, + "evals", + EVAL_FOLDER_NAME.format(train_state.step), + ) + ) + eval_args.metric_log_dir = args.dump_dir + if args.async_eval_gpus is None: + launch_eval(eval_args) + elif get_is_master(): + if wandb.run is not None and args.logging.wandb is not None: + eval_args.wandb = deepcopy(args.logging.wandb) + assert args.async_eval_gpus > 0 + logger.info(f"Launching evals on {args.async_eval_gpus} gpus") + with clean_env(): + launch_job( + StoolArgs( + asdict(eval_args), + script="apps.main.eval", + copy_code=False, + nodes=args.async_eval_gpus // 8, + qos="lowest", + ) + ) + + if preemption_flag["flag"]: + if not saved: + checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + requeue_slurm_job() + sys.exit(0) + + if not saved: + checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + gc.collect() + + +def main(): + """ + The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments + This accepts arguments as a dot list + So if the dataclass looks like + + @dataclass + class DummyArgs: + name: str + model: LMTransformerArgsgs + + @dataclass + class LMTransformerArgsgs: + dim: int + + Then you can pass model.dim=32 to change values in LMTransformerArgsgs + or just name=tictac for top level attributes. + + The behavior here is as follows: + 1. We instantiate TrainArgs with its default values + 2. We override those default values with the ones in the provided config file + 3. We override the result with the additional arguments provided through command line + + For example, if the config is the following + + model: + dim: 128 + n_layers: 4 + + and you call train.py with train.py model.dim=64 + + Then the final TrainArgs will have + + model: + dim: 64 + n_layers: 4 + + Plus all the default values in TrainArgs dataclass. 
+ """ + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.create(TrainArgs().model_dump()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + train_args = TrainArgs.model_validate(cfg) + train(train_args) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py new file mode 100644 index 0000000..432f7df --- /dev/null +++ b/bytelatent/transformer.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, + parallelize_module, +) +from torch.nn.attention.flex_attention import BlockMask, create_block_mask +from xformers.ops import AttentionBias, fmha + +from bytelatent.base_transformer import ( + BaseTransformer, + BaseTransformerArgs, + RMSNorm, + cross_entropy, +) + + +def create_causal_mask(seqlen, attn_impl, sliding_window): + if sliding_window is not None and attn_impl == "xformers": + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) + elif attn_impl == "xformers": + return fmha.attn_bias.LowerTriangularMask() + elif attn_impl == "sdpa": + return "causal" + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + else: + raise NotImplementedError( + f"Attention {attn_impl} with {sliding_window} sliding window not implemented" + ) + + +def attention_flops_per_token(n_layers, seq_len, dim, causal): + # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30 + return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1)) + + +def get_num_flop_per_token( + num_non_embed_params: int, n_layers: int, dim: int, seq_len: int +) -> int: + return 6 * num_non_embed_params + attention_flops_per_token( + n_layers, seq_len, dim, True + ) + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +class LMTransformerArgs(BaseTransformerArgs): + seed: int = 42 + + vocab_size: int = -1 + weight_tying: bool = False + + sliding_window: int | None = None + + +class LMTransformer(BaseTransformer): + def __init__(self, args: LMTransformerArgs): + super().__init__(args) + self.weight_tying = args.weight_tying + self.sliding_window = args.sliding_window + + assert args.vocab_size > 0 + + self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) + + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + + self.output = nn.Linear( + args.dim, + args.vocab_size, + bias=False, + ) + + if args.weight_tying: + self.output.weight = self.embeddings.tok_embeddings.weight + + def forward( + self, + token_values: torch.Tensor, + target: Optional[torch.Tensor] = None, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, + attn_impl: str = "sdpa", + ): + bsz, seqlen = token_values.shape + + h = self.tok_embeddings(token_values) + + mask = ( + mask + if mask is not None + else create_causal_mask(seqlen, attn_impl, self.sliding_window) + ) + h = super().forward(h, 
tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + + logits = self.output(self.norm(h)) + if target is not None: + return cross_entropy(logits, target) + else: + return logits + + def reset_parameters(self, init_std=None): + # Either use a fixed base std or 1/sqrt(model dim) + super().reset_parameters() + init_std = init_std or (self.dim ** (-0.5)) + self.norm.reset_parameters() + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + if not self.weight_tying: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + +# Optional policy for activation checkpointing. With None, we stick to the default (defined in distributed.py: default_no_recompute_ops) +def get_no_recompute_ops(): + return None + + +# Optional and only used when the fully sharded option (FSDP) is chosen. Highly recommended for large models +def build_fsdp_grouping_plan(model_args: LMTransformerArgs): + group_plan: list[Tuple[str, bool]] = [] + + # Group the token embeddings and the output separately from the transformer layers + group_plan.append(("tok_embeddings", False)) + + # Grouping by layers + for i in range(model_args.n_layers): + group_plan.append((f"layers.{i}", False)) + + group_plan.append(("output", True)) + + return group_plan + + +# Optional and only used for model/tensor parallelism when tp_size > 1 +def tp_parallelize(model, tp_mesh, model_args: LMTransformerArgs, distributed_args): + assert model_args.dim % distributed_args.tp_size == 0 + assert model_args.vocab_size % distributed_args.tp_size == 0 + assert model_args.n_heads % distributed_args.tp_size == 0 + assert (model_args.n_kv_heads or 0) % distributed_args.tp_size == 0 + assert model_args.n_heads % (model_args.n_kv_heads or 1) == 0 + + # Embedding layer tp + main_plan = {} + main_plan["tok_embeddings"] = ColwiseParallel( + input_layouts=Replicate(), output_layouts=Shard(1) + ) + main_plan["norm"] = SequenceParallel() + main_plan["output"] = ColwiseParallel( + input_layouts=Shard(1), output_layouts=Replicate() + ) + + parallelize_module( + model, + tp_mesh, + main_plan, + ) + + # Attention layers tp + for layer in model.layers: + layer_plan = {} + + layer_plan["attention"] = PrepareModuleInput( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ) + layer_plan["attention_norm"] = SequenceParallel() + layer_plan["attention.wq"] = ColwiseParallel() + layer_plan["attention.wk"] = ColwiseParallel() + layer_plan["attention.wv"] = ColwiseParallel() + layer_plan["attention.wo"] = RowwiseParallel(output_layouts=Shard(1)) + + # Feedforward layers tp + layer_plan["feed_forward"] = PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ) + layer_plan["ffn_norm"] = SequenceParallel() + layer_plan["feed_forward.w1"] = ColwiseParallel() + layer_plan["feed_forward.w3"] = ColwiseParallel() + layer_plan["feed_forward.w2"] = RowwiseParallel(output_layouts=Shard(1)) + + parallelize_module( + layer, + tp_mesh, + layer_plan, + ) + + # Adjusting the number of heads and kv heads according to the tp size + attn_layer = layer.attention + attn_layer.n_heads = attn_layer.n_heads // distributed_args.tp_size + attn_layer.n_kv_heads = attn_layer.n_kv_heads // distributed_args.tp_size diff --git a/dev/lint.sh b/dev/lint.sh new file mode 100644 index 0000000..cde5ce1 --- /dev/null +++ b/dev/lint.sh @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +isort . +black .
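A minimal sketch of what build_fsdp_grouping_plan in bytelatent/transformer.py above returns. This is not part of the commit; the LMTransformerArgs field values are hypothetical (only n_layers actually matters for the plan), and the remaining fields are assumed to come from BaseTransformerArgs.

from bytelatent.transformer import LMTransformerArgs, build_fsdp_grouping_plan

# Hypothetical two-layer config; only n_layers affects the grouping plan.
args = LMTransformerArgs(dim=256, n_layers=2, n_heads=8, vocab_size=260)

# Given the function above, this prints:
# [('tok_embeddings', False), ('layers.0', False), ('layers.1', False), ('output', True)]
print(build_fsdp_grouping_plan(args))

Each (module_name, reshard_after_forward) pair then becomes one FSDP group: the embeddings, each transformer layer, and the output projection are wrapped separately.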
diff --git a/fixtures/tokenizer_data.json b/fixtures/tokenizer_data.json new file mode 100644 index 0000000..1aeb9e2 --- /dev/null +++ b/fixtures/tokenizer_data.json @@ -0,0 +1 @@ +{"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 80, 105, 120, 43, 119, 36, 103, 108, 105, 103, 111, 36, 109, 106, 36, 120, 108, 105, 119, 105, 36, 120, 115, 111, 105, 114, 109, 126, 105, 118, 119, 36, 113, 101, 120, 103, 108, 37, 2]]} \ No newline at end of file diff --git a/fixtures/tokenizer_data_bpe_delim.json b/fixtures/tokenizer_data_bpe_delim.json new file mode 100644 index 0000000..c726f6e --- /dev/null +++ b/fixtures/tokenizer_data_bpe_delim.json @@ -0,0 +1 @@ +{"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 3, 80, 105, 120, 3, 43, 3, 119, 3, 36, 103, 108, 105, 103, 111, 3, 36, 109, 106, 3, 36, 120, 108, 105, 119, 105, 3, 36, 120, 115, 111, 105, 114, 3, 109, 126, 105, 118, 119, 3, 36, 113, 101, 120, 103, 108, 3, 37, 2]]} \ No newline at end of file diff --git a/fixtures/tokens_with_entropies.json b/fixtures/tokens_with_entropies.json new file mode 100644 index 0000000..e3589d8 --- /dev/null +++ b/fixtures/tokens_with_entropies.json @@ -0,0 +1 @@ +{"position":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":6,"7":7,"8":8,"9":9,"10":10,"11":11,"12":12,"13":13,"14":14,"15":15,"16":16,"17":17,"18":18,"19":19,"20":20,"21":21,"22":22,"23":23,"24":24,"25":25,"26":26,"27":27,"28":28,"29":29,"30":30,"31":31,"32":32,"33":33,"34":34,"35":35,"36":36,"37":37,"38":38,"39":39,"40":40,"41":41,"42":42,"43":43,"44":44,"45":45,"46":46,"47":47,"48":48,"49":49,"50":50,"51":51,"52":52,"53":53,"54":54,"55":55,"56":56,"57":57,"58":58,"59":59,"60":60,"61":61,"62":62,"63":63,"64":64,"65":65,"66":66,"67":67,"68":68,"69":69,"70":70,"71":71,"72":72,"73":73,"74":74,"75":75,"76":76,"77":77,"78":78,"79":79,"80":80},"tokens":{"0":"<","1":"D","2":"a","3":"e","4":"n","5":"e","6":"r","7":"y","8":"s","9":"_","10":"T","11":"a","12":"r","13":"g","14":"a","15":"r","16":"y","17":"e","18":"n","19":"_","20":"i","21":"s","22":"_","23":"i","24":"n","25":"_","26":"G","27":"a","28":"m","29":"e","30":"_","31":"o","32":"f","33":"_","34":"T","35":"h","36":"r","37":"o","38":"n","39":"e","40":"s","41":",","42":"_","43":"a","44":"_","45":"f","46":"a","47":"n","48":"t","49":"a","50":"s","51":"y","52":"_","53":"e","54":"p","55":"i","56":"c","57":"_","58":"b","59":"y","60":"_","61":"G","62":"e","63":"o","64":"r","65":"g","66":"e","67":"_","68":"R","69":".","70":"R","71":".","72":"_","73":"M","74":"a","75":"r","76":"t","77":"i","78":"n","79":".","80":">"},"token_ids":{"0":1,"1":72,"2":101,"3":105,"4":114,"5":105,"6":118,"7":125,"8":119,"9":36,"10":88,"11":101,"12":118,"13":107,"14":101,"15":118,"16":125,"17":105,"18":114,"19":36,"20":109,"21":119,"22":36,"23":109,"24":114,"25":36,"26":75,"27":101,"28":113,"29":105,"30":36,"31":115,"32":106,"33":36,"34":88,"35":108,"36":118,"37":115,"38":114,"39":105,"40":119,"41":48,"42":36,"43":101,"44":36,"45":106,"46":101,"47":114,"48":120,"49":101,"50":119,"51":125,"52":36,"53":105,"54":116,"55":109,"56":103,"57":36,"58":102,"59":125,"60":36,"61":75,"62":105,"63":115,"64":118,"65":107,"66":105,"67":36,"68":86,"69":50,"70":86,"71":50,"72":36,"73":81,"74":101,"75":118,"76":120,"77":109,"78":114,"79":50,"80":2},"entropies":{"0":3.3949158192,"1":2.1656451225,"2":2.3216569424,"3":2.8214058876,"4":1.5249242783,"5":0.0401624143,"6":0.0981037766,"7":0.0544578359,"8":0.3430138826,"9":1.0546212196,"10":0.25252828,"11":0.1494535804,"12":0.0624754503,"13":0.001355894,"14":0.0050173439,"15":0.0052358187
,"16":0.0011725067,"17":0.0010307421,"18":1.0241208076,"19":3.6867966652,"20":0.4502205253,"21":0.0484119244,"22":2.2572875023,"23":0.3789347112,"24":1.0042934418,"25":2.9090054035,"26":1.8933598995,"27":1.3859074116,"28":0.3827198744,"29":0.2646365762,"30":1.7742085457,"31":0.0136727821,"32":0.0053820172,"33":0.5485631227,"34":0.2064044327,"35":0.0049266233,"36":0.0005439016,"37":0.0007023578,"38":0.0004170335,"39":0.0054524317,"40":1.1938130856,"41":0.0238215197,"42":3.1279797554,"43":1.3883389235,"44":3.0503094196,"45":1.695879817,"46":1.8551058769,"47":1.4570231438,"48":0.0047810897,"49":0.026396824,"50":0.6633765101,"51":0.3141393065,"52":2.8411159515,"53":1.143143177,"54":0.0520330966,"55":0.3398066461,"56":0.4140175879,"57":2.5563707352,"58":1.3370712996,"59":0.0227173548,"60":3.4447185993,"61":1.8576486111,"62":0.8189754486,"63":0.6776530743,"64":0.0677763447,"65":0.212713033,"66":0.1003480032,"67":0.1746164262,"68":0.4123829603,"69":0.5507118702,"70":0.1047425047,"71":0.0194335245,"72":0.001482119,"73":0.0009310447,"74":0.0002176317,"75":0.0076908777,"76":0.0003866984,"77":0.0008008487,"78":1.2395234108,"79":0.4564163089,"80":0.0000461392},"patch":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":5,"7":5,"8":5,"9":5,"10":5,"11":5,"12":5,"13":5,"14":5,"15":5,"16":5,"17":5,"18":5,"19":5,"20":6,"21":6,"22":6,"23":7,"24":7,"25":7,"26":8,"27":9,"28":10,"29":10,"30":10,"31":11,"32":11,"33":11,"34":11,"35":11,"36":11,"37":11,"38":11,"39":11,"40":11,"41":11,"42":11,"43":12,"44":13,"45":14,"46":15,"47":16,"48":17,"49":17,"50":17,"51":17,"52":17,"53":18,"54":18,"55":18,"56":18,"57":18,"58":19,"59":20,"60":20,"61":21,"62":22,"63":22,"64":22,"65":22,"66":22,"67":22,"68":22,"69":22,"70":22,"71":22,"72":22,"73":22,"74":22,"75":22,"76":22,"77":22,"78":22,"79":22,"80":22},"start":{"0":1,"1":1,"2":1,"3":1,"4":1,"5":1,"6":0,"7":0,"8":0,"9":0,"10":0,"11":0,"12":0,"13":0,"14":0,"15":0,"16":0,"17":0,"18":0,"19":0,"20":1,"21":0,"22":0,"23":1,"24":0,"25":0,"26":1,"27":1,"28":1,"29":0,"30":0,"31":1,"32":0,"33":0,"34":0,"35":0,"36":0,"37":0,"38":0,"39":0,"40":0,"41":0,"42":0,"43":1,"44":1,"45":1,"46":1,"47":1,"48":1,"49":0,"50":0,"51":0,"52":0,"53":1,"54":0,"55":0,"56":0,"57":0,"58":1,"59":1,"60":0,"61":1,"62":1,"63":0,"64":0,"65":0,"66":0,"67":0,"68":0,"69":0,"70":0,"71":0,"72":0,"73":0,"74":0,"75":0,"76":0,"77":0,"78":0,"79":0,"80":0}} \ No newline at end of file diff --git a/plot_data/entropy_figure.json b/plot_data/entropy_figure.json new file mode 100644 index 0000000..79e42d9 --- /dev/null +++ b/plot_data/entropy_figure.json @@ -0,0 +1 @@ +{"text":"Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. 
Martin.","threshold":1.335442066192627,"dataframe_json":"{\"position\":{\"0\":0,\"1\":1,\"2\":2,\"3\":3,\"4\":4,\"5\":5,\"6\":6,\"7\":7,\"8\":8,\"9\":9,\"10\":10,\"11\":11,\"12\":12,\"13\":13,\"14\":14,\"15\":15,\"16\":16,\"17\":17,\"18\":18,\"19\":19,\"20\":20,\"21\":21,\"22\":22,\"23\":23,\"24\":24,\"25\":25,\"26\":26,\"27\":27,\"28\":28,\"29\":29,\"30\":30,\"31\":31,\"32\":32,\"33\":33,\"34\":34,\"35\":35,\"36\":36,\"37\":37,\"38\":38,\"39\":39,\"40\":40,\"41\":41,\"42\":42,\"43\":43,\"44\":44,\"45\":45,\"46\":46,\"47\":47,\"48\":48,\"49\":49,\"50\":50,\"51\":51,\"52\":52,\"53\":53,\"54\":54,\"55\":55,\"56\":56,\"57\":57,\"58\":58,\"59\":59,\"60\":60,\"61\":61,\"62\":62,\"63\":63,\"64\":64,\"65\":65,\"66\":66,\"67\":67,\"68\":68,\"69\":69,\"70\":70,\"71\":71,\"72\":72,\"73\":73,\"74\":74,\"75\":75,\"76\":76,\"77\":77,\"78\":78,\"79\":79,\"80\":80},\"tokens\":{\"0\":\"<\",\"1\":\"D\",\"2\":\"a\",\"3\":\"e\",\"4\":\"n\",\"5\":\"e\",\"6\":\"r\",\"7\":\"y\",\"8\":\"s\",\"9\":\"_\",\"10\":\"T\",\"11\":\"a\",\"12\":\"r\",\"13\":\"g\",\"14\":\"a\",\"15\":\"r\",\"16\":\"y\",\"17\":\"e\",\"18\":\"n\",\"19\":\"_\",\"20\":\"i\",\"21\":\"s\",\"22\":\"_\",\"23\":\"i\",\"24\":\"n\",\"25\":\"_\",\"26\":\"G\",\"27\":\"a\",\"28\":\"m\",\"29\":\"e\",\"30\":\"_\",\"31\":\"o\",\"32\":\"f\",\"33\":\"_\",\"34\":\"T\",\"35\":\"h\",\"36\":\"r\",\"37\":\"o\",\"38\":\"n\",\"39\":\"e\",\"40\":\"s\",\"41\":\",\",\"42\":\"_\",\"43\":\"a\",\"44\":\"_\",\"45\":\"f\",\"46\":\"a\",\"47\":\"n\",\"48\":\"t\",\"49\":\"a\",\"50\":\"s\",\"51\":\"y\",\"52\":\"_\",\"53\":\"e\",\"54\":\"p\",\"55\":\"i\",\"56\":\"c\",\"57\":\"_\",\"58\":\"b\",\"59\":\"y\",\"60\":\"_\",\"61\":\"G\",\"62\":\"e\",\"63\":\"o\",\"64\":\"r\",\"65\":\"g\",\"66\":\"e\",\"67\":\"_\",\"68\":\"R\",\"69\":\".\",\"70\":\"R\",\"71\":\".\",\"72\":\"_\",\"73\":\"M\",\"74\":\"a\",\"75\":\"r\",\"76\":\"t\",\"77\":\"i\",\"78\":\"n\",\"79\":\".\",\"80\":\">\"},\"token_ids\":{\"0\":1,\"1\":72,\"2\":101,\"3\":105,\"4\":114,\"5\":105,\"6\":118,\"7\":125,\"8\":119,\"9\":36,\"10\":88,\"11\":101,\"12\":118,\"13\":107,\"14\":101,\"15\":118,\"16\":125,\"17\":105,\"18\":114,\"19\":36,\"20\":109,\"21\":119,\"22\":36,\"23\":109,\"24\":114,\"25\":36,\"26\":75,\"27\":101,\"28\":113,\"29\":105,\"30\":36,\"31\":115,\"32\":106,\"33\":36,\"34\":88,\"35\":108,\"36\":118,\"37\":115,\"38\":114,\"39\":105,\"40\":119,\"41\":48,\"42\":36,\"43\":101,\"44\":36,\"45\":106,\"46\":101,\"47\":114,\"48\":120,\"49\":101,\"50\":119,\"51\":125,\"52\":36,\"53\":105,\"54\":116,\"55\":109,\"56\":103,\"57\":36,\"58\":102,\"59\":125,\"60\":36,\"61\":75,\"62\":105,\"63\":115,\"64\":118,\"65\":107,\"66\":105,\"67\":36,\"68\":86,\"69\":50,\"70\":86,\"71\":50,\"72\":36,\"73\":81,\"74\":101,\"75\":118,\"76\":120,\"77\":109,\"78\":114,\"79\":50,\"80\":2},\"entropies\":{\"0\":3.3949158192,\"1\":2.1656451225,\"2\":2.3216569424,\"3\":2.8214058876,\"4\":1.5249242783,\"5\":0.0401624143,\"6\":0.0981037766,\"7\":0.0544578359,\"8\":0.3430138826,\"9\":1.0546212196,\"10\":0.25252828,\"11\":0.1494535804,\"12\":0.0624754503,\"13\":0.001355894,\"14\":0.0050173439,\"15\":0.0052358187,\"16\":0.0011725067,\"17\":0.0010307421,\"18\":1.0241208076,\"19\":3.6867966652,\"20\":0.4502205253,\"21\":0.0484119244,\"22\":2.2572875023,\"23\":0.3789347112,\"24\":1.0042934418,\"25\":2.9090054035,\"26\":1.8933598995,\"27\":1.3859074116,\"28\":0.3827198744,\"29\":0.2646365762,\"30\":1.7742085457,\"31\":0.0136727821,\"32\":0.0053820172,\"33\":0.5485631227,\"34\":0.2064044327,\"35\":0.0049266233,\"36\":0.0005439016,\"37\":0.0007023578,\"38\"
:0.0004170335,\"39\":0.0054524317,\"40\":1.1938130856,\"41\":0.0238215197,\"42\":3.1279797554,\"43\":1.3883389235,\"44\":3.0503094196,\"45\":1.695879817,\"46\":1.8551058769,\"47\":1.4570231438,\"48\":0.0047810897,\"49\":0.026396824,\"50\":0.6633765101,\"51\":0.3141393065,\"52\":2.8411159515,\"53\":1.143143177,\"54\":0.0520330966,\"55\":0.3398066461,\"56\":0.4140175879,\"57\":2.5563707352,\"58\":1.3370712996,\"59\":0.0227173548,\"60\":3.4447185993,\"61\":1.8576486111,\"62\":0.8189754486,\"63\":0.6776530743,\"64\":0.0677763447,\"65\":0.212713033,\"66\":0.1003480032,\"67\":0.1746164262,\"68\":0.4123829603,\"69\":0.5507118702,\"70\":0.1047425047,\"71\":0.0194335245,\"72\":0.001482119,\"73\":0.0009310447,\"74\":0.0002176317,\"75\":0.0076908777,\"76\":0.0003866984,\"77\":0.0008008487,\"78\":1.2395234108,\"79\":0.4564163089,\"80\":0.0000461392},\"patch\":{\"0\":0,\"1\":1,\"2\":2,\"3\":3,\"4\":4,\"5\":5,\"6\":5,\"7\":5,\"8\":5,\"9\":5,\"10\":5,\"11\":5,\"12\":5,\"13\":5,\"14\":5,\"15\":5,\"16\":5,\"17\":5,\"18\":5,\"19\":5,\"20\":6,\"21\":6,\"22\":6,\"23\":7,\"24\":7,\"25\":7,\"26\":8,\"27\":9,\"28\":10,\"29\":10,\"30\":10,\"31\":11,\"32\":11,\"33\":11,\"34\":11,\"35\":11,\"36\":11,\"37\":11,\"38\":11,\"39\":11,\"40\":11,\"41\":11,\"42\":11,\"43\":12,\"44\":13,\"45\":14,\"46\":15,\"47\":16,\"48\":17,\"49\":17,\"50\":17,\"51\":17,\"52\":17,\"53\":18,\"54\":18,\"55\":18,\"56\":18,\"57\":18,\"58\":19,\"59\":20,\"60\":20,\"61\":21,\"62\":22,\"63\":22,\"64\":22,\"65\":22,\"66\":22,\"67\":22,\"68\":22,\"69\":22,\"70\":22,\"71\":22,\"72\":22,\"73\":22,\"74\":22,\"75\":22,\"76\":22,\"77\":22,\"78\":22,\"79\":22,\"80\":22},\"start\":{\"0\":1,\"1\":1,\"2\":1,\"3\":1,\"4\":1,\"5\":1,\"6\":0,\"7\":0,\"8\":0,\"9\":0,\"10\":0,\"11\":0,\"12\":0,\"13\":0,\"14\":0,\"15\":0,\"16\":0,\"17\":0,\"18\":0,\"19\":0,\"20\":1,\"21\":0,\"22\":0,\"23\":1,\"24\":0,\"25\":0,\"26\":1,\"27\":1,\"28\":1,\"29\":0,\"30\":0,\"31\":1,\"32\":0,\"33\":0,\"34\":0,\"35\":0,\"36\":0,\"37\":0,\"38\":0,\"39\":0,\"40\":0,\"41\":0,\"42\":0,\"43\":1,\"44\":1,\"45\":1,\"46\":1,\"47\":1,\"48\":1,\"49\":0,\"50\":0,\"51\":0,\"52\":0,\"53\":1,\"54\":0,\"55\":0,\"56\":0,\"57\":0,\"58\":1,\"59\":1,\"60\":0,\"61\":1,\"62\":1,\"63\":0,\"64\":0,\"65\":0,\"66\":0,\"67\":0,\"68\":0,\"69\":0,\"70\":0,\"71\":0,\"72\":0,\"73\":0,\"74\":0,\"75\":0,\"76\":0,\"77\":0,\"78\":0,\"79\":0,\"80\":0}}"} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e2ecd0d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +[tool.isort] +profile = "black" +known_bytelatent = "bytelatent" +known_apps = "apps" +sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c6d87f1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +numpy +omegaconf +msgspec +rouge-score +sacrebleu +sentencepiece +tiktoken +fsspec +blobfile +wandb +viztracer +lm-eval +scipy +pynvml +datatrove +orjson +luigi +pydantic +altair +submitit +typer +rich diff --git a/setup/create_env.sh b/setup/create_env.sh new file mode 100644 index 0000000..9c7dad9 --- /dev/null +++ b/setup/create_env.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +#SBATCH --job-name=env_creation +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --gres=gpu:8 +#SBATCH --exclusive +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=0 +#SBATCH --time=01:00:00 + +# Exit immediately if a command exits with a non-zero status +set -e + +# Start timer +start_time=$(date +%s) + +# Get the current date +current_date=$(date +%y%m%d) + +# Create environment name with the current date +env_prefix=blt_$current_date + +# Create the conda environment + +source $CONDA_ROOT/etc/profile.d/conda.sh +conda create -n $env_prefix python=3.11 -y -c anaconda +conda activate $env_prefix + +echo "Currently in env $(which python)" + +# Install packages +pip install torch==2.5.0 xformers --index-url https://download.pytorch.org/whl/cu121 +pip install ninja +pip install --requirement requirements.txt + +# End timer +end_time=$(date +%s) + +# Calculate elapsed time in seconds +elapsed_time=$((end_time - start_time)) + +# Convert elapsed time to minutes +elapsed_minutes=$((elapsed_time / 60)) + +echo "Environment $env_prefix created and all packages installed successfully in $elapsed_minutes minutes!" diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py new file mode 100644 index 0000000..1aacf8f --- /dev/null +++ b/setup/download_prepare_hf_data.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import argparse +import os +import subprocess +import time + +import requests +from huggingface_hub import snapshot_download + + +def run_command(command): + print(f"Running: {command}") + subprocess.run(command, shell=True, check=True) + + +def download_dataset(repo_id, local_dir, allow_patterns): + print(f"Downloading dataset from {repo_id}...") + max_retries = 5 + retry_delay = 10 # seconds + for attempt in range(max_retries): + try: + snapshot_download( + repo_id, + repo_type="dataset", + local_dir=local_dir, + allow_patterns=allow_patterns, + resume_download=True, + max_workers=16, # Don't hesitate to increase this number to lower the download time + ) + break + except requests.exceptions.ReadTimeout: + if attempt < max_retries - 1: + print(f"Timeout occurred. Retrying in {retry_delay} seconds...") + time.sleep(retry_delay) + else: + raise + print(f"Dataset downloaded to {local_dir}") + + +def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): + from datatrove.executor import LocalPipelineExecutor + from datatrove.pipeline.readers import ParquetReader + from datatrove.pipeline.writers import JsonlWriter + + pipeline_exec = LocalPipelineExecutor( + pipeline=[ + ParquetReader( + src_dir, + file_progress=True, + doc_progress=True, + glob_pattern="**/*.parquet", + ), + JsonlWriter( + tgt_dir, + output_filename=dataset + ".chunk.${rank}.jsonl", + compression=None, + ), + ], + tasks=ntasks, + logging_dir=os.path.join(work_dir, "datatrove"), + ) + pipeline_exec.run() + + +def setup_terashuf(work_dir): + terashuf_dir = os.path.join(work_dir, "terashuf") + terashuf_executable = os.path.join(terashuf_dir, "terashuf") + + if os.path.exists(terashuf_executable): + print("terashuf executable already exists. 
Skipping setup.") + return terashuf_dir + + print("Setting up terashuf...") + run_command(f"git clone https://github.com/alexandres/terashuf {terashuf_dir}") + run_command(f"make -C {terashuf_dir}") + return terashuf_dir + + +def main(dataset, memory, data_dir, seed=42, nchunks=32): + # Configuration + repo_id = { + "fineweb_edu": "HuggingFaceFW/fineweb-edu", + "fineweb_edu_10bt": "HuggingFaceFW/fineweb-edu", + "dclm_baseline_1.0": "mlfoundations/dclm-baseline-1.0", + "dclm_baseline_1.0_10prct": "mlfoundations/dclm-baseline-1.0", + }[dataset] + src_dir = f"{data_dir}/{dataset}" + out_dir = f"{src_dir}_shuffled" + os.makedirs(out_dir, exist_ok=True) + work_dir = src_dir # Work inside the dataset's source directory + prefix = f"{dataset}.chunk." + orig_extension = { + "fineweb_edu": ".jsonl", + "fineweb_edu_10bt": ".jsonl", + "dclm_baseline_1.0": ".jsonl.zst", + "dclm_baseline_1.0_10prct": ".jsonl.zst", + }[dataset] + cat_command = { + "fineweb_edu": "cat", + "fineweb_edu_10bt": "cat", + "dclm_baseline_1.0": "zstdcat", + "dclm_baseline_1.0_10prct": "zstdcat", + }[dataset] + allow_patterns = { + "fineweb_edu": None, + "fineweb_edu_10bt": "sample/10BT/*", + "dclm_baseline_1.0": "*.jsonl.zst", + "dclm_baseline_1.0_10prct": "global-shard_01_of_10/*.jsonl.zst", + }[dataset] + suffix = ".jsonl" + k_validation = 10000 # Number of lines to take from each chunk for validation + + # Setup terashuf + terashuf_dir = setup_terashuf(work_dir) + + # Download dataset + download_dataset(repo_id, src_dir, allow_patterns) + + if "fineweb" in dataset: + parquet_to_jsonl(dataset, work_dir, src_dir, src_dir) + + # Set up environment variables + os.environ["MEMORY"] = f"{memory}" + os.environ["SEED"] = f"{seed}" + + # Run the original shuffling and splitting command + terashuf_executable = os.path.join(terashuf_dir, "terashuf") + run_command( + f"ulimit -n 100000 && " + f"find {src_dir} -type f -name '*{orig_extension}' -print0 | xargs -0 {cat_command} | {terashuf_executable} | " + f"split -n r/{nchunks} -d --suffix-length 2 --additional-suffix {suffix} - {out_dir}/{prefix}" + "; trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' SIGPIPE;" + ) + + # Create validation set and remove lines from chunks + validation_file = f"{out_dir}/{dataset}.val{suffix}" + for i in range(nchunks): + chunk_file = f"{out_dir}/{prefix}{i:02d}{suffix}" + run_command(f"head -n {k_validation} {chunk_file} >> {validation_file}") + run_command(f"sed -i '1,{k_validation}d' {chunk_file}") + + print("All tasks completed successfully!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("dataset", type=str) + parser.add_argument("memory", type=float, default=8) + parser.add_argument("--data_dir", type=str, default="data") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--nchunks", type=int, default=32) + + args = parser.parse_args() + + main(args.dataset, args.memory, args.data_dir, args.seed, args.nchunks) diff --git a/setup/download_tokenizer.py b/setup/download_tokenizer.py new file mode 100644 index 0000000..fc9c8d5 --- /dev/null +++ b/setup/download_tokenizer.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates.
+import argparse +import os +from typing import Optional + +from requests.exceptions import HTTPError + +TOKENIZER = { + "llama2": ("meta-llama/Llama-2-7b", "tokenizer.model"), + "llama3": ("meta-llama/Meta-Llama-3-8B", "original/tokenizer.model"), + "gemma": ("google/gemma-2-9b", "tokenizer.model"), +} + + +def main(tokenizer_name: str, path_to_save: str, api_key: Optional[str] = None): + if tokenizer_name in TOKENIZER: + repo_id, filename = TOKENIZER[tokenizer_name] + + from huggingface_hub import hf_hub_download + + try: + hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=path_to_save, + local_dir_use_symlinks=False, + token=api_key if api_key else None, + ) + except HTTPError as e: + if e.response.status_code == 401: + print( + "You need to pass a valid `--api_key=...` to download private checkpoints." + ) + else: + raise e + else: + from tiktoken import get_encoding + + if "TIKTOKEN_CACHE_DIR" not in os.environ: + os.environ["TIKTOKEN_CACHE_DIR"] = path_to_save + try: + get_encoding(tokenizer_name) + except ValueError: + print( + f"Tokenizer {tokenizer_name} not found. Please check the name and try again." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("tokenizer_name", type=str) + parser.add_argument("tokenizer_dir", type=str) + parser.add_argument("--api_key", type=str, default="") + args = parser.parse_args() + + main( + tokenizer_name=args.tokenizer_name, + path_to_save=args.tokenizer_dir, + api_key=args.api_key, + )
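For reference, a minimal usage sketch for setup/download_tokenizer.py. This is not part of the commit; the output directory and token below are placeholders, and it assumes the snippet is run from the repository root so that setup/ is importable.

# Equivalent CLI: python setup/download_tokenizer.py llama3 tokenizers/llama3 --api_key=<HF_TOKEN>
from setup.download_tokenizer import main as download_tokenizer

download_tokenizer(
    tokenizer_name="llama3",           # llama2, llama3, gemma, or a tiktoken encoding name
    path_to_save="tokenizers/llama3",  # placeholder output directory
    api_key="hf_xxx",                  # placeholder token for gated Hugging Face repos
)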