From bcc039bb75aec70385fd5c148497d4ae86b526b5 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 12 Dec 2024 15:32:30 -0800 Subject: [PATCH] Initial commit --- .github/workflows/black.yml | 12 + .github/workflows/isort.yml | 10 + .gitignore | 168 +++ .prettierrc | 8 + CODE_OF_CONDUCT.md | 80 ++ CONTRIBUTING.md | 36 + LICENSE | 28 + README.md | 117 ++ apps/__init__.py | 0 apps/main/__init__.py | 0 apps/main/configs/eval.yaml | 35 + apps/main/configs/llama_1B.yaml | 87 ++ apps/main/configs/llama_7B.yaml | 95 ++ apps/main/eval.py | 354 ++++++ apps/main/generate.py | 463 +++++++ apps/main/lingua_train.py | 654 ++++++++++ blt-figure.jpg | Bin 0 -> 54944 bytes blt-figure.pdf | Bin 0 -> 62517 bytes bytelatent/.DS_Store | Bin 0 -> 6148 bytes bytelatent/__init__.py | 3 + bytelatent/args.py | 199 +++ bytelatent/base_transformer.py | 585 +++++++++ bytelatent/checkpoint.py | 311 +++++ bytelatent/configs/debug.yaml | 110 ++ bytelatent/constants.py | 5 + bytelatent/data/__init__.py | 1 + bytelatent/data/data_types.py | 115 ++ bytelatent/data/iterators/__init__.py | 1 + .../data/iterators/abstract_iterator.py | 23 + bytelatent/data/iterators/arrow_iterator.py | 216 ++++ bytelatent/data/iterators/looping_iterator.py | 36 + .../data/iterators/multiprocess_iterator.py | 243 ++++ bytelatent/data/iterators/packing_iterator.py | 226 ++++ .../data/iterators/preprocess_iterator.py | 111 ++ .../data/iterators/sampling_iterator.py | 66 + .../data/iterators/sequence_iterator.py | 122 ++ .../data/iterators/test_arrow_iterator.py | 89 ++ bytelatent/data/iterators/test_iters.py | 162 +++ bytelatent/data/ngram_processor.py | 146 +++ bytelatent/data/patcher.py | 609 ++++++++++ bytelatent/distributed.py | 478 ++++++++ bytelatent/entropy_model.py | 36 + bytelatent/float8.py | 152 +++ bytelatent/logger.py | 129 ++ bytelatent/metrics.py | 232 ++++ bytelatent/model/__init__.py | 1 + bytelatent/model/blt.py | 1064 +++++++++++++++++ bytelatent/model/local_models.py | 356 ++++++ bytelatent/model/transformer.py | 199 +++ bytelatent/model/utils.py | 116 ++ bytelatent/optim.py | 162 +++ bytelatent/plotting/__init__.py | 1 + .../plotting/config_entropy_figure.yaml | 3 + .../plotting/config_scaling_figures.yaml | 4 + bytelatent/plotting/entropy_figure.py | 85 ++ bytelatent/plotting/scaling_figures.py | 108 ++ bytelatent/preprocess/__init__.py | 1 + bytelatent/preprocess/data_pipeline.py | 74 ++ bytelatent/preprocess/parallel_entropies.py | 108 ++ bytelatent/preprocess/preprocess_entropies.py | 141 +++ bytelatent/probe.py | 694 +++++++++++ bytelatent/profiling.py | 133 +++ bytelatent/stool.py | 237 ++++ bytelatent/test_blt.py | 471 ++++++++ bytelatent/test_entropy_model.py | 55 + bytelatent/tokenizers/__init__.py | 1 + bytelatent/tokenizers/abstract_tokenizer.py | 19 + bytelatent/tokenizers/blt_tokenizer.py | 150 +++ bytelatent/tokenizers/build_tokenizer.py | 69 ++ bytelatent/tokenizers/byte_tokenizer.py | 35 + bytelatent/tokenizers/constants.py | 12 + .../tokenizers/sentence_piece_tokenizer.py | 59 + bytelatent/tokenizers/test_blt_tokenizer.py | 41 + bytelatent/tokenizers/tiktoken_tokenizer.py | 86 ++ bytelatent/train.py | 651 ++++++++++ bytelatent/transformer.py | 216 ++++ dev/lint.sh | 3 + fixtures/tokenizer_data.json | 1 + fixtures/tokenizer_data_bpe_delim.json | 1 + fixtures/tokens_with_entropies.json | 1 + plot_data/entropy_figure.json | 1 + pyproject.toml | 5 + requirements.txt | 22 + setup/create_env.sh | 48 + setup/download_prepare_hf_data.py | 156 +++ setup/download_tokenizer.py | 60 + 86 files changed, 12203 
insertions(+) create mode 100644 .github/workflows/black.yml create mode 100644 .github/workflows/isort.yml create mode 100644 .gitignore create mode 100644 .prettierrc create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 README.md create mode 100644 apps/__init__.py create mode 100644 apps/main/__init__.py create mode 100644 apps/main/configs/eval.yaml create mode 100644 apps/main/configs/llama_1B.yaml create mode 100644 apps/main/configs/llama_7B.yaml create mode 100644 apps/main/eval.py create mode 100644 apps/main/generate.py create mode 100644 apps/main/lingua_train.py create mode 100644 blt-figure.jpg create mode 100644 blt-figure.pdf create mode 100644 bytelatent/.DS_Store create mode 100644 bytelatent/__init__.py create mode 100644 bytelatent/args.py create mode 100644 bytelatent/base_transformer.py create mode 100644 bytelatent/checkpoint.py create mode 100644 bytelatent/configs/debug.yaml create mode 100644 bytelatent/constants.py create mode 100644 bytelatent/data/__init__.py create mode 100644 bytelatent/data/data_types.py create mode 100644 bytelatent/data/iterators/__init__.py create mode 100644 bytelatent/data/iterators/abstract_iterator.py create mode 100644 bytelatent/data/iterators/arrow_iterator.py create mode 100644 bytelatent/data/iterators/looping_iterator.py create mode 100644 bytelatent/data/iterators/multiprocess_iterator.py create mode 100644 bytelatent/data/iterators/packing_iterator.py create mode 100644 bytelatent/data/iterators/preprocess_iterator.py create mode 100644 bytelatent/data/iterators/sampling_iterator.py create mode 100644 bytelatent/data/iterators/sequence_iterator.py create mode 100644 bytelatent/data/iterators/test_arrow_iterator.py create mode 100644 bytelatent/data/iterators/test_iters.py create mode 100644 bytelatent/data/ngram_processor.py create mode 100644 bytelatent/data/patcher.py create mode 100644 bytelatent/distributed.py create mode 100644 bytelatent/entropy_model.py create mode 100644 bytelatent/float8.py create mode 100644 bytelatent/logger.py create mode 100644 bytelatent/metrics.py create mode 100644 bytelatent/model/__init__.py create mode 100644 bytelatent/model/blt.py create mode 100644 bytelatent/model/local_models.py create mode 100644 bytelatent/model/transformer.py create mode 100644 bytelatent/model/utils.py create mode 100644 bytelatent/optim.py create mode 100644 bytelatent/plotting/__init__.py create mode 100644 bytelatent/plotting/config_entropy_figure.yaml create mode 100644 bytelatent/plotting/config_scaling_figures.yaml create mode 100644 bytelatent/plotting/entropy_figure.py create mode 100644 bytelatent/plotting/scaling_figures.py create mode 100644 bytelatent/preprocess/__init__.py create mode 100644 bytelatent/preprocess/data_pipeline.py create mode 100644 bytelatent/preprocess/parallel_entropies.py create mode 100644 bytelatent/preprocess/preprocess_entropies.py create mode 100644 bytelatent/probe.py create mode 100644 bytelatent/profiling.py create mode 100644 bytelatent/stool.py create mode 100644 bytelatent/test_blt.py create mode 100644 bytelatent/test_entropy_model.py create mode 100644 bytelatent/tokenizers/__init__.py create mode 100644 bytelatent/tokenizers/abstract_tokenizer.py create mode 100644 bytelatent/tokenizers/blt_tokenizer.py create mode 100644 bytelatent/tokenizers/build_tokenizer.py create mode 100644 bytelatent/tokenizers/byte_tokenizer.py create mode 100644 bytelatent/tokenizers/constants.py create mode 100644 
bytelatent/tokenizers/sentence_piece_tokenizer.py create mode 100644 bytelatent/tokenizers/test_blt_tokenizer.py create mode 100644 bytelatent/tokenizers/tiktoken_tokenizer.py create mode 100644 bytelatent/train.py create mode 100644 bytelatent/transformer.py create mode 100644 dev/lint.sh create mode 100644 fixtures/tokenizer_data.json create mode 100644 fixtures/tokenizer_data_bpe_delim.json create mode 100644 fixtures/tokens_with_entropies.json create mode 100644 plot_data/entropy_figure.json create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 setup/create_env.sh create mode 100644 setup/download_prepare_hf_data.py create mode 100644 setup/download_tokenizer.py diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..433932f --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,12 @@ +name: Lint with Black + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable + with: + version: "24.8.0" diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml new file mode 100644 index 0000000..16f1abf --- /dev/null +++ b/.github/workflows/isort.yml @@ -0,0 +1,10 @@ +name: Lint with isort + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: isort/isort-action@master diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..56891a9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,168 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. 
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ +*.out + +figures/ +.vscode/ +.DS_Store + diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..ae6f541 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,8 @@ +{ + "overrides": [ + { + "files": "*.yaml", + "options": { "tabWidth": 2 } + } + ] +} diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..3232ed6 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. 
+ +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..88d3ab0 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,36 @@ +# Contributing to + +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests + +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") + +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues + +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License + +By contributing to BLT, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. 
diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..bc559a9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright 2024 Meta + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list +of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this +list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may +be used to endorse or promote products derived from this software without specific +prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY +EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT +SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..689c8e4 --- /dev/null +++ b/README.md @@ -0,0 +1,117 @@ +# Byte Latent Transformer + +This repository contains code for our paper: "Byte Latent Transformer: Patches Scale Better Than Tokens" + +- [Paper Link](https://dl.fbaipublicfiles.com/blt/BLT__Patches_Scale_Better_Than_Tokens.pdf) + +## Abstract + +We introduce the Byte Latent Transformer (BLT), a new byte-level LLM architecture that, +for the first time, matches tokenization-based LLM performance at scale, with significant improvements +in inference efficiency and robustness. BLT encodes bytes into dynamically sized patches, which serve +as the primary units of computation. Patches are segmented dynamically based on the entropy of the +next byte, allocating more compute and model capacity where there is more data complexity. The BLT +architecture includes new attention mechanisms to maximize the information flow between byte and +patch hidden representations and a new type of byte-sequence memory. We present the first scaling +study of byte-level models up to 8B parameters and 8T training bytes, showing for the first time +that we can train a model end-to-end at scale from bytes with no tokenization or other preprocessing. +Scaling trends reveal training and inference efficiency benefits from dynamically selecting very long +patches on average, along with qualitative improvements in reasoning and long-tail generalization +from modeling byte sequences. + +![BLT Architecture Diagram](blt-figure.jpg) + +## Development Status + +We are actively updating the BLT code to make it easier to reproduce our results. +Please file an issue and/or be patient while we make more of our code public! + +## Quick start + +The following commands launch a SLURM job that creates the conda environment for BLT. +Creating the environment should take around 5 minutes, not counting downloads.
+ +```bash +git clone https://github.com/facebookresearch/blt +cd blt + +bash setup/create_env.sh +# or if you have access to a SLURM cluster +sbatch setup/create_env.sh +``` + +Once that is done, you can activate the environment: + +```bash +conda activate blt_ +``` + +Use the provided script to download and prepare data from Hugging Face (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`). +This command will download `fineweb_edu` and prepare it for training in the `./data` directory, and lets you specify the amount of memory `terashuf` (the tool used to shuffle samples) will be allocated. By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` with the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details. + +```bash +python setup/download_prepare_hf_data.py fineweb_edu --data_dir ./data --seed 42 --nchunks +``` + +To download the tokenizer (here llama3), use the following script: + +```bash +python setup/download_tokenizer.py llama3 --api_key +``` + +Now launch a debug job to check that everything works. **The provided configurations are templates; you need to adapt them (change `dump_dir`, `data.root_dir`, `data.tokenizer.path`, etc.) for them to work.** + +```bash +# stool stands for SLURM tool! +python -m bytelatent.stool script=bytelatent.train config=apps/bytelatent/configs/debug.yaml nodes=1 partition= +# if you want to launch locally you can use torchrun +torchrun --nproc-per-node 8 -m bytelatent.train config=apps/bytelatent/configs/debug.yaml +# or you can also launch on 1 GPU +python -m bytelatent.train config=apps/bytelatent/configs/debug.yaml +``` + +When using `stool`, if a job crashes, it can be relaunched using sbatch: + +```bash +sbatch path/to/dump_dir/submit.slurm +``` + +## Linting + +To lint, run the following command: + +``` +bash dev/lint.sh +``` + +## Citation + +BLT is partially based on Meta Lingua, so consider citing it in addition to our BLT paper if you re-use our work. + +BLT Paper Citation (will be updated to arXiv soon) + +``` +@article{meta_blt, + author = {Artidoro Pagnoni, Ram Pasunuru, Pedro Rodriguez, John Nguyen, Benjamin Muller, Margaret Li, Chunting Zhou, Lili Yu, Jason Weston, Luke Zettlemoyer, Gargi Ghosh, Mike Lewis, Ari Holtzman†, Srinivasan Iyer}, + title = {Byte Latent Transformer: Patches Scale Better Than Tokens}, + url = {https://github.com/facebookresearch/blt}, + year = {2024} +} +``` + +Lingua Code + +``` +@misc{meta_lingua, + author = {Mathurin Videau, Badr Youbi Idrissi, Daniel Haziza, Luca Wehrstedt, Jade Copet, Olivier Teytaud, David Lopez-Paz}, + title = {{Meta Lingua}: A minimal {PyTorch LLM} training library}, + url = {https://github.com/facebookresearch/lingua}, + year = {2024} +} +``` + +## License + +The BLT code is partially based on Meta Lingua. + +Meta Lingua is licensed under the BSD-3-Clause license. Refer to the LICENSE file in the top-level directory. diff --git a/apps/__init__.py b/apps/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/main/__init__.py b/apps/main/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/main/configs/eval.yaml b/apps/main/configs/eval.yaml new file mode 100644 index 0000000..4b52ba0 --- /dev/null +++ b/apps/main/configs/eval.yaml @@ -0,0 +1,35 @@ +name: "debug_evals" +# ckpt_dir: !!CHANGETHIS!! +# dump_dir: !!CHANGETHIS!!
+generator: + max_tokens: 8192 + dtype: bf16 + temperature: 1.0 + top_p: 0.95 +harness: + tasks: + - hellaswag + - task: boolq + dataset_kwargs: + trust_remote_code: true + - task: nq_open + num_fewshot: 5 + - piqa + - task: social_iqa + dataset_kwargs: + trust_remote_code: true + - triviaqa + - winogrande + - openbookqa + - arc_easy + - arc_challenge + - race + - commonsense_qa + # - coqa + - copa + - gsm8k + - bbh + - mmlu + - mmlu_pro +validation: + max_steps: 1000 diff --git a/apps/main/configs/llama_1B.yaml b/apps/main/configs/llama_1B.yaml new file mode 100644 index 0000000..786d848 --- /dev/null +++ b/apps/main/configs/llama_1B.yaml @@ -0,0 +1,87 @@ +# dump_dir: !!!CHANGE_THIS!!! +name: large_lm +steps: 60_000 +probe_freq: null +seed: 777 + +optim: + lr: 3e-3 + weight_decay: 0.033 + warmup: 5000 + lr_min_ratio: 0.000001 + clip: 1.0 + +distributed: + fsdp_type: full_shard + compile: true + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +model: + dim: 2048 + n_layers: 25 + n_heads: 16 + +data: + root_dir: data/shuffled + sources: + dclm_baseline_1.0: 100.0 + batch_size: 4 + prefetch_size: 1024 + seq_len: 4096 + n_views: 2 + load_async: true + add_bos: true + add_eos: true + tokenizer: + name: tiktoken + path: tokenizers/cl_toplang_128k.tiktoken + +profiling: + run: true + mem_warmup: 0 + mem_steps: 4 + profile_warmup: 100 + profile_steps: 4 + +checkpoint: + dump: + every: 2500 + keep: 3 + eval: + every: 5000 + keep: -1 + +logging: + freq: 1 + +async_eval_gpus: 8 +eval: + harness: + tasks: + - hellaswag + - task: boolq + dataset_kwargs: + trust_remote_code: true + - piqa + - task: social_iqa + dataset_kwargs: + trust_remote_code: true + - winogrande + - openbookqa + - arc_easy + - arc_challenge + - race + - commonsense_qa + - copa + # - coqa + # - task: nq_open + # num_fewshot: 5 + # - triviaqa + validation: + max_steps: 1000 + generator: + max_tokens: 16384 + dtype: bf16 diff --git a/apps/main/configs/llama_7B.yaml b/apps/main/configs/llama_7B.yaml new file mode 100644 index 0000000..4461dd9 --- /dev/null +++ b/apps/main/configs/llama_7B.yaml @@ -0,0 +1,95 @@ +#python -m lingua.stool config=apps/main/configs/llama2_7B.yaml nodes=32 account=fair_amaia_cw_codegen qos=lowest +# dump_dir: !!!CHANGE_THIS!!! 
+name: "7b_baseline" +steps: 100_000 +grad_acc_steps: 1 +probe_freq: 100 + +seed: 777 +optim: + lr: 1.0e-3 + weight_decay: 0.1 + warmup: 2000 + lr_min_ratio: 0.000001 + clip: 1.0 + +distributed: + fsdp_type: full_shard + compile: true + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +model: + dim: 4096 + n_layers: 32 + n_heads: 32 + rope_theta: 100_000 + ffn_dim_multiplier: 1.0 + multiple_of: 256 + +data: + root_dir: data/shuffled + sources: + dclm_baseline_1.0: 1.0 + batch_size: 2 + prefetch_size: 1024 + seq_len: 4096 + n_views: 2 + load_async: true + tokenizer: + name: tiktoken + path: tokenizers/cl_toplang_128k.tiktoken + +profiling: + run: true + mem_warmup: 0 + mem_steps: 4 + profile_warmup: 100 + profile_steps: 4 + +checkpoint: + dump: + every: 10000 + keep: -1 + eval: + every: 1000 + keep: 3 + +logging: + freq: 1 + +async_eval_gpus: 8 +eval: + dataset_dir: datasets/eval + harness: + tasks: + - hellaswag + - task: boolq + dataset_kwargs: + trust_remote_code: true + - piqa + - task: social_iqa + dataset_kwargs: + trust_remote_code: true + - winogrande + - openbookqa + - arc_easy + - arc_challenge + - race + - commonsense_qa + # - coqa + - copa + - mmlu + - mmlu_pro + # - task: nq_open + # num_fewshot: 5 + # - triviaqa + # - gsm8k + # - bbh + validation: + max_steps: 1000 + generator: + max_tokens: 8192 + dtype: bf16 diff --git a/apps/main/eval.py b/apps/main/eval.py new file mode 100644 index 0000000..ed20f49 --- /dev/null +++ b/apps/main/eval.py @@ -0,0 +1,354 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import json +import logging +import os +from collections import defaultdict +from dataclasses import asdict, dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import torch +from lingua.args import dump_config +from lingua.data import init_choice_state, setup_sources +from lm_eval import simple_evaluate +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from omegaconf import OmegaConf + +from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.distributed import ( + DistributedArgs, + dist_mean_dict, + get_global_rank, + get_world_size, + setup_torch_distributed, +) +from bytelatent.transformer import LMTransformer, LMTransformerArgs + +from apps.main.generate import ( + PackedCausalTransformerGenerator, + PackedCausalTransformerGeneratorArgs, + load_consolidated_model_and_tokenizer, +) + +EVAL_FOLDER_NAME = "{:010d}" + +logger = logging.getLogger() + + +@dataclass +class LMHarnessArgs: + tasks: Optional[List[Any]] = None + num_fewshot: Optional[int] = None + device: Optional[str] = None + use_cache: Optional[str] = None + cache_requests: bool = False + rewrite_requests_cache: bool = False + delete_requests_cache: bool = False + limit: Optional[Union[int, float]] = None + bootstrap_iters: int = 100000 + check_integrity: bool = False + write_out: bool = False + log_samples: bool = True + system_instruction: Optional[str] = None + apply_chat_template: Union[bool, str] = False + fewshot_as_multiturn: bool = False + gen_kwargs: Optional[str] = None + verbosity: str = "INFO" + predict_only: bool = False + random_seed: int = 0 + numpy_random_seed: int = 1234 + torch_random_seed: int = 1234 + fewshot_random_seed: int = 1234 + + +@dataclass +class ValidationArgs: + max_steps: Optional[int] = ( + None # If None the whole validation file is used -> /!\ This number of steps is gpu 
dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) + ) + use_val_from_train_src: bool = True # Use the validation set from training sources + root_dir: str = "" + sources: List[str] = field(default_factory=list) # Other sources to eval on + + +@dataclass +class EvalArgs: + name: str = "evals" + dump_dir: Optional[str] = None + metric_log_dir: Optional[str] = None + ckpt_dir: str = "" + generator: PackedCausalTransformerGeneratorArgs = field( + default_factory=PackedCausalTransformerGeneratorArgs + ) + harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs) + validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs) + + wandb: Optional[Any] = None + + global_step: Optional[int] = None # for in-training evaluation + + +def all_dicts_same(dict_list): + if not dict_list: # Check if the list is empty + return True + + # Compare each dictionary to the first one + first_dict = dict_list[0] + return all(d == first_dict for d in dict_list) + + +class MockAccelerator: + def gather(self, tensor): + l = [torch.zeros_like(tensor) for _ in range(get_world_size())] + torch.distributed.all_gather(l, tensor) + return torch.stack(l) + + def wait_for_everyone(self): + torch.distributed.barrier() + + +# Light wrapper around generator for lm-eval harness +class EvalHarnessLM(LM): + def __init__(self, generator): + super().__init__() + self.generator = generator + self.accelerator = MockAccelerator() + self._rank = get_global_rank() + self._world_size = get_world_size() + self.device = generator.device + + def generate_until(self, requests: List[Instance]) -> List[str]: + prompts, gen_args = zip(*[req.args for req in requests]) + assert all_dicts_same(gen_args), "Doesn't support different gen args for now" + gen_args = gen_args[0] + temperature = gen_args.get("temperature", 0.0) + top_p = gen_args.get("top_p", None) + top_k = gen_args.get("top_k", None) + until = gen_args.get("until", []) + + self.generator.temperature = temperature + self.generator.top_p = top_p + self.generator.top_k = top_k + self.generator.until = until + generations, _, _ = self.generator.generate(prompts) + filtered_gen = [] + for g in generations: + for e in until: + g = g.replace(e, "") + filtered_gen.append(g) + return filtered_gen + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + prompts, continuations = zip(*[req.args for req in requests]) + inputs = [req.args[0] + req.args[1] for req in requests] + max_gen_len = self.generator.max_gen_len + # We temporarily lower max gen len + self.generator.max_gen_len = 1 + _, lls, greedy = self.generator.generate(inputs) + results = [] + for p, ll, gr in zip(prompts, lls, greedy): + p_len = len( + self.generator.tokenizer.encode(p, add_bos=False, add_eos=False) + ) + results.append((ll[p_len:].sum().item(), gr[p_len:].all().item())) + + self.generator.max_gen_len = max_gen_len + return results + + def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + prompts = [req.args[0] for req in requests] + max_gen_len = self.generator.max_gen_len + # We temporarily lower max gen len + self.generator.max_gen_len = 1 + _, lls, _ = self.generator.generate(prompts) + results = [] + for ll in lls: + results.append((ll.sum().item(),)) + self.generator.max_gen_len = max_gen_len + + return results + + +def eval_on_val(generator, val_args: ValidationArgs, train_cfg): + srcs = {} + for src in val_args.sources: + path = os.path.join(val_args.root_dir, src) + srcs[path] = 1.0 + for src in train_cfg.data.sources: + path = 
os.path.join(train_cfg.data.root_dir, src) + srcs[path] = 1.0 + + multi_state = init_choice_state( + "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl" + ) + path_to_iter = setup_sources(multi_state) + + max_gen_len = generator.max_gen_len + # We temporarily lower max gen len + generator.max_gen_len = 1 + + all_val_metrics = {} + for src in path_to_iter: + jsonl_iterator = path_to_iter[src] + texts = [] + logger.info(f"Running validation on {src}...") + for step, (content, state) in enumerate(jsonl_iterator): + if state["current_iter"] > 0 or ( + val_args.max_steps is not None and step >= val_args.max_steps + ): + break + content_key = "text" if ("text" in content) else "content" + texts.append(content[content_key]) + + _, loglikelihood, _ = generator.generate(texts) + + metrics = defaultdict(list) + for i, ll in enumerate(loglikelihood): + tmp = ll.sum().item() + metrics["nll"].append(tmp) + metrics["nll_per_token"].append(tmp / len(ll)) + metrics["nll_per_char"].append(tmp / len(texts[i])) + + metrics["avg_seqlen"].append(len(ll)) + + for m in metrics: + metrics[m] = sum(metrics[m]) / len(metrics[m]) + metrics.update(dist_mean_dict(metrics)) + logger.info(f"Validation on {src} done. Metrics: {metrics}") + + name = os.path.basename(src) + if name in all_val_metrics: + logger.warning( + f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1" + ) + name = f"{name}_1" + all_val_metrics[name] = metrics + + generator.max_gen_len = max_gen_len + + return all_val_metrics + + +def launch_eval(cfg: EvalArgs): + if not torch.distributed.is_initialized(): + setup_torch_distributed(DistributedArgs()) + if ( + Path(cfg.ckpt_dir).exists() + and (Path(cfg.ckpt_dir) / "params.json").exists() + and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None + ): + consolidate_path = Path(cfg.ckpt_dir) + else: + consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER + if not consolidate_path.exists() and get_global_rank() == 0: + consolidate_path = consolidate_checkpoints(cfg.ckpt_dir) + + Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True) + dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False) + + consolidate_path = str(consolidate_path) + torch.distributed.barrier() + logger.info("Loading model") + model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( + consolidate_path, + model_cls=LMTransformer, + model_args_cls=LMTransformerArgs, + ) + logger.info("Model loaded") + model.eval() + generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer) + + wrap = EvalHarnessLM(generator) + results = simple_evaluate(wrap, **asdict(cfg.harness)) + val_results = None + if cfg.validation: + val_results = eval_on_val(generator, cfg.validation, train_cfg) + if get_global_rank() == 0: + with open(Path(cfg.dump_dir) / "results.json", "w") as f: + f.write(json.dumps(results)) + logger.info(f"All evaluation results: {results['results']}") + if val_results is not None: + with open(Path(cfg.dump_dir) / "validation.json", "w") as f: + f.write(json.dumps(val_results)) + logger.info(f"All validation results: {val_results}") + if cfg.metric_log_dir and get_global_rank() == 0: + metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl" + + logger.info(f"Writing metric logs to {metric_log_path}") + timestamp = { + "created_at": datetime.utcnow().isoformat(), + } + if cfg.global_step is not None: + timestamp["global_step"] = cfg.global_step + print( + json.dumps(timestamp | results["results"]), + file=open(metric_log_path, 
mode="a"), + flush=True, + ) + + val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl" + if val_results is not None: + print( + json.dumps(timestamp | val_results), + file=open(val_log_path, mode="a"), + flush=True, + ) + + del generator + + +def main(): + """ + The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments + This accepts arguments as a dot list + So if the dataclass looks like + + @dataclass + class DummyArgs: + name: str + model: LMTransformerArgsgs + + @dataclass + class LMTransformerArgsgs: + dim: int + + Then you can pass model.dim=32 to change values in LMTransformerArgsgs + or just name=tictac for top level attributes. + + The behavior here is as follows: + 1. We instantiate EvalArgs with its default values + 2. We override those default values with the ones in the provided config file + 3. We override the result with the additional arguments provided through command line + + For example, if the config is the following + + model: + dim: 128 + n_layers: 4 + + and you call eval.py with eval.py model.dim=64 + + Then the final TrainArgs will have + + model: + dim: 64 + n_layers: 4 + + Plus all the default values in EvalArgs dataclass. + """ + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.structured(EvalArgs()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_object(cfg) + launch_eval(cfg) + + +if __name__ == "__main__": + main() diff --git a/apps/main/generate.py b/apps/main/generate.py new file mode 100644 index 0000000..a3a8627 --- /dev/null +++ b/apps/main/generate.py @@ -0,0 +1,463 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional + +import torch +from lingua.args import dataclass_from_dict +from lingua.tokenizers.abstract_tokenizer import Tokenizer +from lingua.tokenizers.build_tokenizer import build_tokenizer +from omegaconf import OmegaConf +from torch import nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import create_block_mask +from tqdm import tqdm + +from bytelatent.base_transformer import ( + Attention, + causal_mask, + generate_doc_mask_mod, + lengths_to_local_ids, + lengths_to_start_ids, +) +from bytelatent.checkpoint import CONSOLIDATE_NAME +from bytelatent.transformer import LMTransformer, LMTransformerArgs + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + + +def sample_top_k(probs, k): + topk_value, _ = torch.topk(probs, k) # batch_sz x topk + min_value_top_k = topk_value[:, [-1]] + probs[probs < min_value_top_k] = 0.0 + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs, num_samples=1) + return next_token + + +def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None): + shape = logits.shape + logits = logits.flatten(end_dim=-2) + if temperature > 0.0: + probs = torch.softmax(logits / temperature, dim=-1) + + if top_p is not None: + next_token = sample_top_p(probs, top_p) + elif top_k is not None: + next_token = sample_top_k(probs, top_k) + else: + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(logits, dim=-1) + return next_token.view(shape[:-1]) + + +def pack_prompts(prompts: List[int]): + res = [] + lengths = [] + for i, p in enumerate(prompts): + p = torch.tensor(p, dtype=torch.long) + l = p.size(0) + res.append(p) + lengths.append(l) + lengths = torch.tensor(lengths, dtype=torch.long) + res = torch.cat(res) + return res, lengths + + +def batch_prompts(prompts, max_elements, lengths=None): + batches = [] + current_batch = [] + current_count = 0 + + for i in range(len(prompts)): + prt = prompts[i] + prompt_size = len(prt) if lengths is None else lengths[i] + if current_count + prompt_size <= max_elements: + current_batch.append(prt) + current_count += prompt_size + else: + if current_batch: # Add the current batch to batches + batches.append(current_batch) + # Start a new batch with the current prompt + current_batch = [prt] + current_count = prompt_size + + # Add the last batch if it contains any prompts + if current_batch: + batches.append(current_batch) + + return batches + + +class KVCache(nn.Module): + def __init__(self, bsz, seqlen, n_heads, head_dim, dtype, device): + super().__init__() + shape = (bsz, seqlen, n_heads, head_dim) + self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype, device=device)) + self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype, device=device)) + self.offset = 0 + + def reset(self): + self.k_cache.zero_() + self.v_cache.zero_() + self.offset = 0 + + def update(self, k_val, v_val, tok_idx): + # input_pos: [B], k_val: [B, S, H, D] + self.k_cache.index_copy_(1, self.offset + tok_idx, k_val) + self.v_cache.index_copy_(1, self.offset + tok_idx, v_val) + return self.k_cache, self.v_cache + + 
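As a minimal sketch of how the packing helpers above behave (toy token lists; assumes `pack_prompts` and `batch_prompts` from this file are in scope):

```python
# Three "tokenized prompts" of different lengths.
prompts = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]

# pack_prompts concatenates them into one flat tensor and records each length:
# packed  -> tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]); lengths -> tensor([3, 2, 4])
packed, lengths = pack_prompts(prompts)

# batch_prompts greedily groups prompts so that each batch holds at most
# max_elements tokens; with max_elements=5 this yields [[p0, p1], [p2]].
batches = batch_prompts(prompts, max_elements=5)
```

The generator below builds on this layout: one packed token tensor plus per-prompt lengths, with attention masks that keep the packed sequences independent.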
+@dataclass +class PackedCausalTransformerGeneratorArgs: + temperature: float = 0.0 + top_p: Optional[float] = None + top_k: Optional[int] = None + max_gen_len: int = 512 # Maximum number of tokens to generate + max_tokens: int = 1024 # Maximum number of tokens that can go through the model + max_prompt_len: Optional[int] = None + until: List[str] = field(default_factory=list) + compile_prefilling: bool = False + reduce_generation_overhead: bool = False + show_progress: bool = False + dtype: Optional[str] = "bf16" + device: Optional[str] = "cuda" + + +class PackedCausalTransformerGenerator: + def __init__( + self, + cfg: PackedCausalTransformerGeneratorArgs, + model: nn.Module, + tokenizer: Tokenizer, + ): + """ + This class wraps a causal transformer model with its corresponding tokenizer + and provides an efficient way to pack prompts together and do generation on + the packed sequence. + + For example, if we had the prompts "Hello, I am a " and "Initiating calibration " + then this class will concatenate those sequences (pack them together): + "Hello, I am a Initiating calibration" + and make the necessary attention masks such that a sequence only attends to itself + during prefilling and generation. + + This class creates a fixed size cache of size max_tokens or sum of prompt sizes + + the max number of generated tokens per sequence. + """ + self.model = model + self.tokenizer = tokenizer + self.temperature = cfg.temperature + self.top_p = cfg.top_p + self.top_k = cfg.top_k + + self.max_gen_len = cfg.max_gen_len + self.max_tokens = cfg.max_tokens + self.max_prompt_len = cfg.max_prompt_len + self.until = cfg.until + self.max_until_size = max([len(e) for e in self.until]) if self.until else 1 + self.device = cfg.device + + # Compile if necessary + self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling) + self.generate_next_token = torch.compile( + self.generate_next_token, + mode="reduce-overhead", + disable=not cfg.reduce_generation_overhead, + ) + + self.show_progress = cfg.show_progress + self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype] + + self.prefill_doc_id, self.prefill_tok_id = None, None + self.padded_doc_id, self.padded_tok_id = None, None + self.current_doc_id, self.current_tok_id = None, None + self.padded_doc_start = None + self.prefill_mask = None + + def clear_cache(self, offset): + for module in self.model.modules(): + if isinstance(module, Attention): + if not hasattr(module, "kv_cache"): + module.kv_cache = KVCache( + 1, + self.max_tokens, + module.n_kv_heads, + module.head_dim, + self.dtype, + self.device, + ) + module.kv_cache.offset = offset + + @torch.compiler.disable + def setup_prefilling(self, lengths: torch.Tensor): + # The KV cache is a fixed size tensor of size max_tokens that we need + # to update in order to do correct autoregressive generation. + + # Here we will generate token by token but on multiple sequences + # at once. To do so, we need to have an attention mask that makes + # each sequence independent. + + # Each sequence will write to its allocated space in the KV Cache. + # We allocate len(seq) + max_gen_len to each sequence in the cache.
+ + # We will generate max_gen_len for each document + padded_lengths = lengths + self.max_gen_len + max_tokens = self.max_tokens or padded_lengths.sum().item() + # The last document might have more padding to fill up to max_tokens + padded_lengths[-1] += max_tokens - padded_lengths.sum() + + # This is the start index in the cache for each document + self.padded_doc_start = lengths_to_start_ids(padded_lengths) + # For example with ab--123--cdef-- + # this would be 0, 4, 9 if max_gen_len is 2 + + # We repeat interleave to align with tokens for prefilling + # Ex: ab--123--cdef-- + # 000044444999999 + prefill_offset = torch.repeat_interleave(self.padded_doc_start, lengths) + # This offset will make sure the tokens are written to the + # correct positions in the cache during prefilling + + # We either init the cache or clear it by resetting the offset to prefill_offset + self.clear_cache(prefill_offset) + + # The prefilling mask looks like the following for + # the two packed sequences ab and 123 : ab123 + # Where spaces are empty cache positions + # keys + # ab---123--- + # queries a 10000000000 + # b 11000000000 + # 1 00000100000 + # 2 00000110000 + # 3 00000111000 + # We make sure to skip the empty cache positions + # and only attend to positions within the same sequence + doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths, padded_lengths) + self.prefill_mask = create_block_mask( + doc_mask_mod, 1, None, lengths.sum(), max_tokens + ) + + # This creates the prefilling token ids which look like + # the following for the packed sequence abcdefg1234 + # abcdefg1234 + # 01234560123 + # The token id gives us the position within each sequence + # This is used to compute ROPE and to update the cache + # At each forward pass the current tokens are written to + # offset + tok_id + self.prefill_doc_id, self.prefill_tok_id = lengths_to_local_ids(lengths) + + # This creates the padded token and document ids + # which look like the following for the packed sequence ab123 + # ab---123--- ab---123--- + # padded_doc_id 00000111111 padded_tok_id 01234012345 + # This will later be useful for the attention mask at generation + self.padded_doc_id, self.padded_tok_id = lengths_to_local_ids(padded_lengths) + + @torch.compiler.disable + def setup_generation(self, lengths): + # KV Cache offset is set to the start of the padded documents + for module in self.model.modules(): + if isinstance(module, Attention): + module.kv_cache.offset = self.padded_doc_start + # The token ids during generations correspond to the lengths of each doc + # current_tok_id will be incremented during generation + self.current_tok_id = lengths.clone() + # Since we're generating one token per document + # the document id is just an arange + self.current_doc_id = torch.arange(lengths.size(0), device=lengths.device) + + # From here on some methods for generation + def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor): + # Prefilling is done by taking multiple packed sequences and + # doing block diagonal attention on them so they remain independent + self.setup_prefilling(lengths=lengths) + prefill_out = self.model.forward( + tokens, + tok_idx=self.prefill_tok_id, + mask=self.prefill_mask, + attn_impl="flex_attention", + ) + self.setup_generation(lengths=lengths) + return prefill_out + + def generate_next_token(self, current_token): + # Since we're doing generation with multiple sequences at once + # we need to ignore tokens and cache entries from other sequences + # or in the future. 
+ # Example mask : + # keys + # abc--1234-- + # queries c 11100000000 + # 4 00000111100 + + # mask shape : (n_seqs, cache_size) + doc_mask = self.current_doc_id.unsqueeze(1) == self.padded_doc_id.unsqueeze(0) + caus_mask = self.current_tok_id.unsqueeze(1) >= self.padded_tok_id.unsqueeze(0) + mask = doc_mask & caus_mask + out = self.model.forward( + current_token, + tok_idx=self.current_tok_id, # n_seqs + mask=mask, + attn_impl="sdpa", + ) + self.current_tok_id += 1 + return out + + @torch.inference_mode() + def generate(self, prompts): + # Tokenize + prompts = [ + self.tokenizer.encode(p, add_bos=True, add_eos=False) for p in prompts + ] + # Truncate + max_seqlen = ( + self.max_tokens + if not hasattr(self.model, "max_seqlen") + else self.model.max_seqlen + ) + max_prompt_len = self.max_prompt_len or min( + max_seqlen - self.max_gen_len, self.max_tokens - self.max_gen_len + ) + prompts = [p[-max_prompt_len:] for p in prompts] + # Account for the generation in lengths + padded_lengths = [len(p) + self.max_gen_len for p in prompts] + generation = [] + loglikelihood = [] + greedy = [] + it = batch_prompts(prompts, self.max_tokens, lengths=padded_lengths) + if self.show_progress: + it = tqdm(it) + for batch in it: + n_seqs = len(batch) + generated_tokens = [[] for _ in range(n_seqs)] + is_done = [False for _ in range(n_seqs)] + packed_batch, lengths = pack_prompts(batch) + packed_batch, lengths = packed_batch.cuda(), lengths.cuda() + n_seqs = lengths.size(0) + + # Prefilling cache + prompt_logits = self.prefill(packed_batch.unsqueeze(0), lengths) + # Selecting last token in each prompt + all_tokens = sample_tokens( + prompt_logits, self.temperature, self.top_p, self.top_k + ) + start_token = all_tokens[:, lengths.cumsum(0) - 1] + + for seq_id, tok in enumerate(start_token.squeeze(0).tolist()): + generated_tokens[seq_id].append(tok) + + current_token = start_token + for i in range(1, self.max_gen_len): + + next_logits = self.generate_next_token(current_token) + next_token = sample_tokens( + next_logits.clone(), self.temperature, self.top_p, self.top_k + ) + + for seq_id, tok in enumerate(next_token.squeeze(0).tolist()): + if not is_done[seq_id]: + generated_tokens[seq_id].append(tok) + current_end_str = self.tokenizer.decode( + generated_tokens[seq_id][-self.max_until_size :] + ) + contains_end_string = any( + [e in current_end_str for e in self.until] + ) + is_done[seq_id] = ( + contains_end_string or tok == self.tokenizer.eos_id + ) + if all(is_done): + break + + current_token = next_token + + generation.extend([self.tokenizer.decode(g) for g in generated_tokens]) + + for p, logit in zip( + batch, prompt_logits.squeeze(0).split(lengths.tolist()) + ): + x = logit[:-1] + y = torch.tensor(p[1:], device=x.device) + loglikelihood.append(-F.cross_entropy(x, y, reduction="none").cpu()) + greedy.append((x.argmax(dim=-1) == y).cpu()) + + return generation, loglikelihood, greedy + + +def load_consolidated_model_and_tokenizer( + consolidated_path, + model_cls=LMTransformer, + model_args_cls=LMTransformerArgs, +): + ckpt_path = Path(consolidated_path) + config = ckpt_path / "params.json" + config = OmegaConf.load(config) + + param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ + config.distributed.model_dtype + ] + model_args = dataclass_from_dict(model_args_cls, config.model, strict=False) + tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path) + model = model_cls(model_args) + st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True) 
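+    # Note: weights_only=True restricts torch.load to deserializing tensors and basic containers, so loading the consolidated checkpoint does not execute arbitrary pickled code.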
+ model.load_state_dict(st_dict["model"]) + model = model.cuda().eval() + for param in model.parameters(): + param.data = param.data.to(dtype=param_dtype) + return model, tokenizer, config + + +def main(): + # Load CLI arguments (overrides) and combine with a YAML config + cfg = OmegaConf.from_cli() + gen_cfg = dataclass_from_dict( + PackedCausalTransformerGeneratorArgs, cfg, strict=False + ) + print(cfg) + + model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt) + + generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer) + + # Allow multiple prompts + prompts = [] + while True: + prompt = input("Enter a prompt (or press enter to finish): ") + if not prompt: + break + prompts.append(prompt) + + # Start generation + start_time = time.time() + generation, loglikelihood, greedy = generator.generate(prompts) + end_time = time.time() + + # Calculate tokens per second + total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation) + tokens_per_second = total_tokens / (end_time - start_time) + + # Display the results + for i, gen in enumerate(generation): + print(f"\nPrompt {i+1}: {prompts[i]}") + print(f"Generated Text: {gen}") + + print(f"\nTokens per second: {tokens_per_second:.2f}") + + +if __name__ == "__main__": + main() diff --git a/apps/main/lingua_train.py b/apps/main/lingua_train.py new file mode 100644 index 0000000..bdb47da --- /dev/null +++ b/apps/main/lingua_train.py @@ -0,0 +1,654 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import gc +import logging +import os +import sys +from contextlib import ExitStack +from copy import deepcopy +from dataclasses import asdict, dataclass, field +from pathlib import Path +from timeit import default_timer as timer +from typing import Any, Dict, Optional + +import torch +import torch.distributed +import wandb +import xformers.profiler +from lingua.args import dataclass_from_dict, dump_config, flatten_dict +from lingua.data import ( + DataArgs, + PackTokensState, + build_dataloader_from_args, + init_dataloader_state_from_args, +) +from lingua.tokenizers.build_tokenizer import TokenizerArgs +from omegaconf import OmegaConf +from pydantic import BaseModel +from torch.distributed._tensor import DTensor +from torch.distributed.checkpoint.stateful import Stateful +from torch.optim import lr_scheduler + +from bytelatent.checkpoint import ( + CheckpointArgs, + CheckpointManager, + load_from_checkpoint, +) +from bytelatent.distributed import ( + DistributedArgs, + EnvironmentArgs, + check_model_value_range, + clean_env, + dist_mean_dict, + get_device_mesh, + get_is_master, + get_world_size, + init_signal_handler, + parallelize_model, + requeue_slurm_job, + setup_env, + setup_torch_distributed, +) +from bytelatent.logger import init_logger +from bytelatent.metrics import ( + GPUMemoryMonitor, + LoggingArgs, + MetricLogger, + get_num_params, +) +from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.probe import AutoProbeD +from bytelatent.profiling import ProfilerArgs, maybe_run_profiler +from bytelatent.stool import StoolArgs, launch_job +from bytelatent.transformer import ( + LMTransformer, + LMTransformerArgs, + build_fsdp_grouping_plan, + get_no_recompute_ops, + get_num_flop_per_token, + tp_parallelize, +) + +logger = logging.getLogger() + + +class TrainArgs(BaseModel): + name: str = "lingua" + dump_dir: str = "" + + seed: int = 42 + + # Number of gradient 
accumulation steps + # Total batch size is batch_size*grad_acc_steps + grad_acc_steps: int = 1 + + gc_collect_freq: int = 1000 + probe_freq: int | None = None + + # Nb optimizer steps to take + steps: int = 1000 + + data: DataArgs + optim: OptimArgs + model: LMTransformerArgs + distributed: DistributedArgs + env: EnvironmentArgs + + checkpoint: CheckpointArgs + profiling: ProfilerArgs + logging: LoggingArgs + + # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus + async_eval_gpus: int | None = None + eval: Any | None = None + + +@dataclass +class TrainState(Stateful): + step: int # Nb of steps taken by the optimizer + acc_step: int # Nb of accumulation steps done since last optimizer step + scheduler: lr_scheduler.LambdaLR + data_loader_state: PackTokensState + + def state_dict(self) -> Dict[str, Any]: + return { + "step": self.step, + "acc_step": self.acc_step, + "data_loader_state": self.data_loader_state, + "scheduler": self.scheduler.state_dict(), + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.acc_step = state_dict["acc_step"] + self.data_loader_state = PackTokensState(**state_dict["data_loader_state"]) + self.scheduler.load_state_dict(state_dict["scheduler"]) + + +def validate_train_args(args: TrainArgs, output_size: int): + if args.model.vocab_size < 0: + logger.info(f"Setting model output size to {args.model.vocab_size}") + args.model.vocab_size = output_size + assert ( + args.model.vocab_size == output_size + ), "Vocab size should be the same as output size" + + assert args.dump_dir, "Dump dir not set" + + if args.checkpoint.path is None: + logger.info(f"Setting checkpoint path to {args.checkpoint.path}") + args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints") + + for source in args.data.sources: + data_path = os.path.join(args.data.root_dir, source) + assert os.path.exists(data_path), f"{data_path} doesn't exist" + + if ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + != get_world_size() + ): + assert get_world_size() % args.distributed.dp_shard == 0 + args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard + + assert args.distributed.dp_replicate % args.distributed.tp_size == 0 + args.distributed.dp_replicate = ( + args.distributed.dp_replicate // args.distributed.tp_size + ) + + logger.warning( + f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}" + ) + assert ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + == get_world_size() + ) + + if args.distributed.fsdp_type == "no_shard": + assert ( + args.distributed.dp_shard == 1 + and args.distributed.dp_replicate == get_world_size() + ) + + args.model.max_seqlen = args.data.seq_len + + if args.distributed.tp_size == 1: + logger.warning( + "Tensor parallelism has not been tested for a while, use at your own risk" + ) + + assert ( + args.probe_freq != args.profiling.mem_steps + ), "Don't profile during probe step" + assert ( + args.probe_freq != args.profiling.profile_steps + ), "Don't profile during probe step" + if args.logging.wandb is not None: + args.logging.wandb.name = args.name + + if args.probe_freq is not None: + assert ( + args.distributed.tp_size == 1 + ), "Probing not supported with tensor parallelism" + assert ( + args.distributed.selective_activation_checkpointing is False + ), "Probing not supported with selective activation checkpointing" + + +preemption_flag = 
dict(flag=False) + + +def set_preemption_flag(signum, frame): + logger.warning("Signal handler called with signal " + str(signum)) + logger.warning("Preemption ! checkpointing asap and exiting.") + preemption_flag["flag"] = True + + +def every_n_steps(train_state, freq, acc_step=None, acc_freq=None): + test = train_state.step % freq == 0 + if acc_step is not None: + test = test and (train_state.acc_step == acc_step) + elif acc_freq is not None: + test = test and ((train_state.acc_step % acc_freq) == 0) + return test + + +def train(args: TrainArgs): + with ExitStack() as context_stack: + tokenizer_args = TokenizerArgs( + name=args.data.name, + init_kwargs=args.data.tokenizer.init_kwargs, + ) + tokenizer = tokenizer_args.build() + validate_train_args( + args, + tokenizer.n_words, + ) + if get_is_master(): + os.makedirs(args.dump_dir, exist_ok=True) + dump_config(args, Path(args.dump_dir) / "config.yaml") + init_logger(Path(args.dump_dir) / "train.log") + init_signal_handler(set_preemption_flag) # For handling preemption signals. + setup_env(args.env) + setup_torch_distributed(args.distributed) + world_mesh = get_device_mesh(args.distributed) + logger.info(f"Starting job: {args.name}") + + # build dataloader + # need dp world size and rank + dp_mesh = world_mesh["dp_replicate"] + dp_degree = dp_mesh.size() + dp_rank = dp_mesh.get_local_rank() + if args.distributed.dp_shard > 1: + dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank() + dp_degree *= world_mesh["dp_shard"].size() + + logger.info(f"Running on dp rank : {dp_rank}") + logger.info(f"Running on dp size : {dp_degree}") + + torch.manual_seed(args.seed) + logger.info("Building model") + + # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory + with torch.device("meta"): + model = LMTransformer(args.model) + logger.info("Model is built !") + + model_param_count = get_num_params(model) + + model = parallelize_model( + model, + world_mesh, + args.model, + args.distributed, + fsdp_grouping_plan=build_fsdp_grouping_plan(args.model), + tp_parallelize=tp_parallelize, + no_recompute_ops=get_no_recompute_ops(), + ) + + # Once we shard the model on different gpus we can actually initialize the model + # First we create empty tensors of the correct shapes + model = model.to_empty(device="cuda") + # Then we init the model. 
Please make sure this function initializes *ALL* parameters + # and buffers, otherwise you will have random values in the unitialized tensors + # which will silently fail (give nan gradients for example) + + if args.checkpoint.init_ckpt_path: + logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + load_from_checkpoint( + args.checkpoint.init_ckpt_path, model, model_key="model" + ) # Put model_key="" if its directly the model checkpoint + model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded + else: + with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + torch.manual_seed(args.model.seed) + model.init_weights() + check_model_value_range(model, range=10.0, std=1.0) + + # log model size + + logger.info(f"Model size: {model_param_count:,} total parameters") + + gpu_memory_monitor = GPUMemoryMonitor("cuda") + logger.info( + f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) " + f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory" + ) + logger.info(f"GPU memory usage: {gpu_memory_monitor}") + + # build optimizer after apply parallelisms to the model + optimizer, scheduler = build_optimizer(model, args.optim, args.steps) + data_loader_state = init_dataloader_state_from_args( + args.data, dp_rank, dp_degree + ) + + train_state = TrainState( + step=0, + acc_step=0, + data_loader_state=data_loader_state, + scheduler=scheduler, + ) + + checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint) + checkpoint.load(model, optimizer, train_state, world_mesh) + # Either load from latest checkpoint or start from scratch + if args.probe_freq is not None: + if get_is_master(): + os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True) + torch.distributed.barrier() + probe = AutoProbeD( + model, + ( + Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl" + if (dp_rank % 128 == 0) + else None + ), + ) + probe_mod = model._orig_mod if args.distributed.compile else model + + gc.disable() + + # train loop + model.train() + metric_logger = context_stack.enter_context( + MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args) + ) + data_loader = context_stack.enter_context( + build_dataloader_from_args( + args.data, + state=train_state.data_loader_state, + ) + ) + torch_profiler = context_stack.enter_context( + maybe_run_profiler(args.dump_dir, model, args.profiling) + ) + + nwords_since_last_log = 0 + time_last_log = timer() + gc.collect() + while train_state.step < args.steps: + # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 + train_state.acc_step += 1 + train_state.acc_step = train_state.acc_step % args.grad_acc_steps + + # get batch + curr_lr = float(optimizer.param_groups[0]["lr"]) + data_load_start = timer() + batch, train_state.data_loader_state = next(data_loader) + batch = torch.tensor( + batch, + dtype=torch.long, + ) + + if every_n_steps(train_state, args.gc_collect_freq, acc_step=0): + logger.info("garbage collection") + # we do garbage collection manually otherwise different processes + # run the GC at different times so they slow down the whole pipeline + gc.collect() + + input_ids = batch[:, :, 0].cuda() + labels = batch[:, :, 1].cuda() + data_load_time = round(timer() - data_load_start, 4) + nwords_since_last_log += input_ids.numel() + + bsz, seqlen = labels.shape + + # forward + start_timer = torch.cuda.Event(enable_timing=True) + end_timer = torch.cuda.Event(enable_timing=True) + start_timer.record() + + # 
This is an automatic probe that will compute statistics + # of all linears' inputs, weights and outputs + # along with attention logits and entropy + # both in forward and backward pass + if (args.probe_freq is not None) and every_n_steps( + train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps + ): + # Here we do a fake forward and backward pass on a smaller + # batch size to avoid OOM + # This assumes the model has no stateful layers (batch norm..) + assert ( + next(probe_mod.parameters()).grad is None + ), "Can't probe model if grads are not reset" + + with probe: + probe.metadata = { + "it": train_state.step, + "global_step": train_state.step, + "loop": "lingua", + } + # Non compiled model uses roughly 2x memory in our exps + # So we divide bsz by 2 or seqlen by 2 + probe_bsz = max(1, bsz // 2) + probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2) + probe_loss = probe_mod( + input_ids[:probe_bsz, :probe_seq], + labels[:probe_bsz, :probe_seq], + ) + probe_loss.backward() + # We zero grads to cancel this fake step + optimizer.zero_grad() + + assert ( + next(probe_mod.parameters()).grad is None + ), "Probe model shouldn't have grads at this point" + loss = model(input_ids, labels) + + # We scale loss with grad_acc_steps so the gradient is the same + # regardless of grad_acc_steps + loss = loss / args.grad_acc_steps + # backward on scaled loss to create scaled gradients + loss.backward() + # For logging we undo that scaling + loss = loss.detach() * args.grad_acc_steps + + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + + grad_norm = ( + grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm + ).item() + + # optimizer step + if train_state.acc_step == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + train_state.step += 1 + + # updates the scale for next iteration + # training iteration complete + end_timer.record() + + torch.cuda.synchronize() + + curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4) + + # if profiler is active + if torch_profiler: + xformers.profiler.step() + + # log metrics + if every_n_steps( + train_state, + args.logging.freq, + acc_step=None if args.logging.acc_freq else 0, + acc_freq=args.logging.acc_freq, + ): + time_delta = timer() - time_last_log + wps = nwords_since_last_log / (time_delta * args.distributed.tp_size) + + gpu_mem_stats = gpu_memory_monitor.get_peak_stats() + + total_acc_steps = ( + args.grad_acc_steps * train_state.step + train_state.acc_step + ) + tokens_per_gpu = ( + total_acc_steps * args.data.batch_size * args.data.seq_len + ) + total_tokens = dp_degree * tokens_per_gpu + # This is an estimate and the correct values may change + # if you change the architecture + # Use xformer's analyze profile trace to get actual measurement + FLOPS = ( + get_num_flop_per_token( + model_param_count - args.model.vocab_size * args.model.dim, + args.model.n_layers, + args.model.dim, + args.data.seq_len, + ) + * wps + ) + metrics = flatten_dict( + { + "global_step": train_state.step, + "acc_step": train_state.acc_step, + "speed": { + "wps": wps, + "FLOPS": FLOPS, + "curr_iter_time": curr_iter_time, + "data_load_time": data_load_time, + }, + "optim": { + "grad_norm": grad_norm, + "lr": curr_lr, + "total_tokens": total_tokens, + }, + "memory": gpu_mem_stats._asdict(), + }, + sep="/", + ) + + to_sync = {} + to_sync["loss/out"] = loss.item() + metrics.update(dist_mean_dict(to_sync)) + + if get_is_master(): + metric_logger.log(metrics) + + 
gpu_memory_monitor.reset_peak_stats() + nwords_since_last_log = 0 + time_last_log = timer() + logger.info( + f"step: {train_state.step}" + f" acc: {train_state.acc_step}" + f" loss: {round(loss.item(),4):>7}" + f" grad: {grad_norm:.2e}" + f" flops: {FLOPS:.2e}" + f" wps: {wps:.2e}" + f" iter: {curr_iter_time:>7}" + f" data: {data_load_time:>5}" + f" lr: {curr_lr:.2e}" + f" mem: {gpu_mem_stats.max_active_pct:.0f}%" + f" pow: {gpu_mem_stats.power_draw/1000} W" + ) + + saved = False + if every_n_steps( + train_state, args.checkpoint.dump.every, acc_step=0 + ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + saved = checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + + if args.eval is not None and every_n_steps( + train_state, args.checkpoint.eval.every, acc_step=0 + ): + from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval + + eval_args = dataclass_from_dict(EvalArgs, args.eval) + + eval_args.global_step = train_state.step + eval_args.ckpt_dir = str(checkpoint.existing_saves[-1]) + eval_args.dump_dir = str( + os.path.join( + args.dump_dir, + "evals", + EVAL_FOLDER_NAME.format(train_state.step), + ) + ) + eval_args.metric_log_dir = args.dump_dir + if args.async_eval_gpus is None: + launch_eval(eval_args) + elif get_is_master(): + if wandb.run is not None and args.logging.wandb is not None: + eval_args.wandb = deepcopy(args.logging.wandb) + assert args.async_eval_gpus > 0 + logger.info(f"Launching evals on {args.async_eval_gpus} gpus") + with clean_env(): + launch_job( + StoolArgs( + asdict(eval_args), + script="apps.main.eval", + copy_code=False, + nodes=args.async_eval_gpus // 8, + qos="lowest", + ) + ) + + if preemption_flag["flag"]: + if not saved: + checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + requeue_slurm_job() + sys.exit(0) + + if not saved: + checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + gc.collect() + + +def main(): + """ + The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments + This accepts arguments as a dot list + So if the dataclass looks like + + @dataclass + class DummyArgs: + name: str + model: LMTransformerArgsgs + + @dataclass + class LMTransformerArgsgs: + dim: int + + Then you can pass model.dim=32 to change values in LMTransformerArgsgs + or just name=tictac for top level attributes. + + The behavior here is as follows: + 1. We instantiate TrainArgs with its default values + 2. We override those default values with the ones in the provided config file + 3. We override the result with the additional arguments provided through command line + + For example, if the config is the following + + model: + dim: 128 + n_layers: 4 + + and you call train.py with train.py model.dim=64 + + Then the final TrainArgs will have + + model: + dim: 64 + n_layers: 4 + + Plus all the default values in TrainArgs dataclass. 
+ """ + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.structured(TrainArgs()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_object(cfg) + + train(cfg) + + +if __name__ == "__main__": + main()
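The docstring above describes the config precedence used by lingua_train.py's main(): dataclass defaults are loaded first, then overridden by the YAML file passed as config=..., then by any CLI dot-list arguments. The following minimal sketch is not part of the patch (DummyArgs and LMArgs are hypothetical stand-ins for TrainArgs and LMTransformerArgs); it only illustrates that precedence with plain OmegaConf calls:

from dataclasses import dataclass, field

from omegaconf import OmegaConf


@dataclass
class LMArgs:
    dim: int = 128
    n_layers: int = 4


@dataclass
class DummyArgs:
    name: str = "lingua"
    model: LMArgs = field(default_factory=LMArgs)


if __name__ == "__main__":
    default_cfg = OmegaConf.structured(DummyArgs())       # 1. dataclass defaults
    file_cfg = OmegaConf.create({"model": {"dim": 128}})  # 2. stands in for OmegaConf.load(cli_args.config)
    cli_cfg = OmegaConf.from_dotlist(["model.dim=64"])    # 3. e.g. `python train.py model.dim=64`
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_cfg) # later sources win
    assert cfg.model.dim == 64 and cfg.model.n_layers == 4
    print(OmegaConf.to_yaml(cfg))

The printed config shows model.dim: 64 while model.n_layers: 4 is kept from the defaults, mirroring the example given in the docstring.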
diff --git a/blt-figure.jpg b/blt-figure.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e27edc61ae21f5f59ac69df40db270ee894fd5b8 GIT binary patch literal 54944 [base85-encoded JPEG payload omitted]
diff --git a/blt-figure.pdf b/blt-figure.pdf new file mode 100644 index 0000000000000000000000000000000000000000..045f15da5409ce99176a116868de6af773c164ac GIT binary patch literal 62517 [base85-encoded PDF payload omitted]
zh&;Igr{szVY`lokg3&)R89sl)A4=_5RMW2Mso!l7Y}wA&_>Rn(3pSNRAiA0Xm+_3#X{Q`^qB8v!Lakdkm zre(6TVLL&21fc^zbQUE~_$+x^Lr&J>M-17%ij$uzO~n0lw3WF*=_K&Ml_jnEE$#LMXsM!C>ehd2giL+V@5H0L%e(0V;8bvrPxd@Ddt~3? zy0w#&^7wGzSpnrw8^$Y!_~lc+C8{H3yX9rGtM!<-l>6M|*4!U8;Ln1vFqxqB^S}fC z8(sN7W#)gXK#|kZJR`RTaY|Fgdw09>UOyk99C9WQCTJ6xduc?*f5m8&ea34(m~0?9 zlSF*pM$3_U+LhfDqxI`Rtpy)d+jE#ZuQ(3FxO6eCR$pXTPqpSm=R|ZkkE{7AYgI96 z+PFWN@)}Q)Bhyv^uOtOKuw-$UHID6~Pp_()7HsM&0sNhYeJN||jMQaZ?U%j-HrQmJ zkzc;InzfOza3d1w(b1r@!FcRWxw?SjskDYP;U|eYt$KGArShnlw@LwE4@-^;QkfmtbU?o%Ex*A40E*EGV;$ zp!H;X#G0;3lJ9VTKH1eCM}6|3FD@xKmH`vVM<5hcJWE1%4JJwbRlpPao`!w#YjZ;b zUP5J#Kba1k|I%%`g8Qh7-&ayhM_QoTll-M#wUdlXr{7T@@6{$@+5;-9r`>@$QiU~X zWS9@5Sd9!@=7>ihVYvp~l&saylcIA7%IWd(D|uR~D4i#F0zu|JuW1&yI6nf;)!=K? z`wYvF@2D?+cVffr4GmxK2JBvTXQ~rh%847j=r^pl)zNQ5Vp?KSsqPH;OOCTB{C|F0 zMdauY`!I&GeJYh`t~J~m@b4b#zkwlg|E>{vT31&ZP`9ywx328P>NAeb)q7HnqcaVj z|IPn$;Qk#18#9gj_O^u}H~FI-4u9fo!Shc8$h9kdWhA++JIqx|vT`L7Dcbl_WmPJd zVt434SzZrQJEB;D*@&tKG8R|m#)46zz49Mb43#UOCcQ&`GWmE2w=fH_)Aui{MA(4P z9^@2d!1es%WzE<119TpSfoVinJ{7fX1{3+~EMEU#k&A6tB3F&C4_fmeQyAyEALF9^ zHa@(kBKSvo1ujwH!#sSRC<)=Ah8|+HpC+}Rlw>S++wg5>Wmf8dr&0HW2K?MVVyK-w z+>3sm=OyOJ@U9_(I?fbdDQQ6=rd-~Mvkd91-p_Y88$P2yn( zTZ-dg?*nHMzlieVCYW~w?KE;u^)7-M*(!C0L;-cjBX1X>W(|G6#qOVU&G_GowhN3h zYQm5utAE7UOP9*IW-pwH8ZoV{3<4q=JBn*mSU_z=jM7x-+TBegY?cC;`5DpK%LA=^ zk&zQ0T$B-N z6joExPX!o~nRVq4tUNp_uGX&q>zf)d@Klt=)8DM!^wQ`MEE- zZ5V<~9yxZII*l+__5?5b=4G8`BV0bJ$h9=MExbE^A!Rk2Y_)jQ4CAr|HF*N;sCK<} zk&>k%4`M%p3H#mc&=8>_%W3T~X|IWp;O5&9Q88%g9Mpa_5<^PTKMbX8rV?++X}kO= zg!%q#pUz*$7RFpkleFDX0NS!g8Pp{#z>R|y1a49U(e{Q=IS_#gKa=T?9{^a z8}V#rZ;pgM+OSjUc4!f6GIO5C7rb9H!lgaL--;ywp?XHCGx`$m13RQlcFy2+S3#OjP_#(ervZ`hP7w^DM_s z@=*1-gp4~|?5qc5_3&DfNy5m~I&BprX0I20`3Gm#EJ*E4 z-*+mE9wR0|S9<2-`e1$J`;E}g#Tm24Ro3fZsOYh{QS1A3*@e9WRQDR(lGV@30i(Q{ z%Z{B*^zdAKSV*?3wJrMzA&GA(LKo+wH{a%Z-&qhtH)bX_4_L+DxNvr`89(y-IP4%jF zt76(GU|rm{rR92-@uBG@J8$($ z`4iWJ{)@Y9oByj}i%Iouu9v}Ojo*YiHIS)J7vwu3F5<5%+nvbejPgI_Oz({?Cquip z9FP)zUeWe-si&`Sv{5+io4V8j+cV-XYC&KmKiQFiv}ZgdjV^UJTI1D9Q6HISjEq+xC*fhdD0z;hOwO zxpsq&6mc!vhku~`p7q#Yy=!+ry=I^D9{xzgDwkE4lfKG7({# zMJ=parkj8g#rH~A+j%+Tvua&UY)RmK!nL+sq=;JeH+yjn`Q>%`8Vyk8KdVsi>(wUIGpa3k!e)paTHt zRJd5URR5np??1s%;lOCXzCwbNfI*;wL!yGe4}uYb$_Wkr&j;+^A8-grPt}I8x#IK1}7=3o;fX zryzI$7WQWxTyhFZDry>5Hg*n9E^ZM~F>wh=DQOi|HFXV5Ep1~HQ!{f5ODks=S2uSL zPp{yR@1bGg5s`^W$tkI6KhiS_3X6(MO3TVC8XB9LTUy)NJNgF(hlWQ+$HwR87k)1; zEw2E#ws&^-_7DCZ9$j8t-`w8aKRiBt$OR4t`7g0R=YI+IKjcCM$prxg1qlWFAs0A= zJLrH!g@Pt#fBr2Ey*rReJRUX9u%Ypp9@_#pArvJ+>7a+r=XJ%$%`*s$M-+%7u@IQ_3|IpMVoD4zxM*Kh7+Q0XMVEWgF5DfoiZ2!4u#DCitXqSlOiH3eg z1d%}2bB(0CxA$886-4OJB9}fix(@C_C4!Z-2?bKU{G?@(+|kt<(vF0J!+0~*@bdCm z$Yd$%ViKt_m!nmnKqX<;_W10j(YA0Yv21-cyd+d76w^@EV3f zQ^m9w9n{9M{RBgCB0=o{Us|c}GAZ{xCKyM_kGAw9v&)=|y=S-|X}`b}@L|~7zNen^ z0cANBl|X!;=RKiWRey(y6E@jyQih8b80Q>fa*}UEUs%ddc#?c%ceN@yte2lhpcnSq zufmp)7|2+Ma-Q~N{xN48)%x4;hzBR!$thQQ_^O=EqrjKF$oE=~nD~Ea`hT3(KOUEr zmFa)^>Hj)(CVJNY?)z`He_Xfd4?5&uD3E_P7@-*c15o)V#quKpKNtKLLFETU;`HN- zZT`vn(`a<+at0R0hW|eEU(T8dis2tm{_o5pZcbuKPXBoAe~J*WvNO}Chx(Ud{IFU6 zm1KbWmpuG8$qe7TRyLD4Y$8>RBYUQxDHu`&M_0EO`%N&Ht8{%g|Yi0C*V-(}RTj2j~)IaXve>du1R{r-U{hyx6|8iUy|GyIGmgTK9X%k3~ z(Di`ITM`Okm6#e7SD~ED{+Peu1F$9)pd4j9=HFYpPC_A3q#j`GYHBjNfOOrMI<>?D z&8J=IaH^q#b(Lo+CksfmFQ4fLU1GT`_p3)mX$}AJ(CgI0j8UmsAffJdlC5(G*6OZm z6DcR)cF^3enWJuamiNWor{NhLlb$CsfY}(6lI~G;?WF(^jnF#Xe7E*P|BT~r*iAnW za+)J!p(boL4NBVpHoiEauVPBqdm@;kF)ibE*H1p0UcXIoKeS~KCPa4mTS1$&`kiJYjB^9h(< z7NWC+N;BS^-RM+Hq-_LE-X48w2-W=jw5r-Sub?(8#M;>qM4E8uqNYNFH@EAEw#@A@ zRBVZct(lV;Dh8Wt4Gr`87g8O=s@S%gZd?V5NtO_O|xZ 
z`EqB42|osbktl4K&_vO2WWgq(?;DcEwU9VAhx1MoM=f3THqfhtUuEgE4kydX?Q$l> z`Z*$IkNJC2cw(%(&O1~- zm@A(Xj-V@>ko$X9xYTibh7~MCH(K!XygPKslvEebDF=BagoY0pwy=O-+${vuy#xwB z`2$6tDYM+{n3rhnVwM-6E8B(Noa`fUj(?8uY(Dy}M4Kf#r(eb#UL4+2pgoO{Yksr1 z@Il_O2{Se?(9&*jU^=OCKnoyFOa9=T3rq(hjxaCwm{&|n;qn%i-KvF=Q2VI?Gs?6` z>z_%c)x5h?U>}n67+Nv4LiaWTCRRtSrDd470SWUkcJ7?}0>B%gA+@ZPVIdEK9!BBf>iT=gXJUM#+x`X!t zCEO`y6F7lNhzg;_c{C>kwLK|uX^h-oNgPpLw=pEQOVeiX8}PX%kJKLMTEU6wIO&Px zVG;Ai^f{!DVp{{nT53g={sutuEtWJVPACdml25s5AI)knYf58bD_b?Snt9d_NDex>J);3Gm$j5k7i3(1_BKFK-y z_*vZF8>dTAg=mAwClxD$C;6A+t5(|F1QrwQvku6^m#qU&)CB-X;HK{c_>lmu4s>%L zd6#o|3YbP9USmrjR@KxxO;~UL4bQTE&FFr^j>TrNe7A?+lT>FsM0fbj< zT1YC+T_+se#V@qa$?VfTtI}8}dCe zxfUxIC)SqV2ReJ-9b8X7vK{I(IX%jj=Uv1nc(1Vtws+9~$`A)lWKZ#(Gl?Gq14es3 z;~k`0pmA7x3Q~38e7fq!$OpR}GI7iJNa2B@Dw1#4tWNB4M`Z-oo7495!v9%F-yxY5 z%F!RFy3c77`3ma}t{oV2%jl8o`O{;9K2xxUa#2DK&gy_{_$}u?r=XgYF(&{1k`uzl zjYj%ph_jh&Q{)AEbBLon5Gc6{UG_#e>>J*~FM|X)`3>jA07X26_PPb^5C9IB0chz3 zBVNx}Kx$8HF1iT91y&x2@K@7=9W&J(!+J(qNX_(GIGZ20YMMHb>{=MFaC2U^i_@DY zBeX&Ps4JEZ-C_(epL8+2itHv1(P%OF0eVSx58xe$w>MbOcHg`y=Pxv-e+3g3qb;B- zx?j+CRE=P2{sH|&?LI7U8=7(ouC$A(MVmYQ4*Xo*6j3@hb>Uh1x*3EA7!NU|SwT_& zhs}_0$em7XN!)(A$X3!LOy1xXKPYOMGwfJy=5kM=s7{sYOHWya+Ux}GHr2$WWTy+o z*;YBbin4lKuJqxsqv3qkWtBKP3e2mwPVDB9GPX#4FK`A)5`A^lG5b6&jEV=jBU z{b^nO*GI5n{f6vtblXkQt4cKaPU(ESfsDNFbB1WWZQ+m(a<+h2piTHSt&K)+zJ`F> za%t7i0uizpyv;G0s^S~FlD+OP%&%1(_$<1CL2d8R8f#ycsp=cNG{oZJr;AlltX4Rz z&2NjSlIG@Hn6^}6i}TghzDhO#MDUoffz7e=jrwRemM%s~F+s87pYPROeAVT-zpR-@ zKZZ(cb2Z?VkTrM1Lg=EeAcL2H(+$a&_&R(J*W9Y=Ik9WA*XSBH16Sa6JJ`xZ1JQNP zl)JR-ErPbKT~EW7S8&r8T%z@>IEKOCYzwJNcz(v$VdgQiQG#=#q6gDF=8Poan|u`E z&A@z{F|IALV3-&S=eJsYx)%~x7C*3pgOFlt{=xx*56O8s>rV@Ji%OI45=1ESnj5mT z+3ZD^D;p{x8Y7m|O;J47vr(xzQR-|-|G9IQL}A3o!-Q(eR^My{8N-X5hwLfV5K$Az zuh(Bwuf`Tp3x2_GCRW)LKR-)GJBV;a>0Z$LZQj|<3zxMORMviO%}lcS&4H5J`S&p= zdhNxQR&Krq9vcyl-D+#i?C;YPr!{y?S6-3!&Af}uDCV~JsDpliA@NjwWna?4Cg$RR zx+PJrpD&@|`S^4KWLv`Wo|#OcA|Jn$<&pX&UspMYW5UjG!$a%Yw$){NN8Q#%z&+@8 zW6eq4U>+rS4E9g;z4}@7SgwLo&wtr*V{)*fU2TSm_D@Ln^&U%0!9z9gi=v9Jq_P5h zFIGJM0fimp)PeR*;^IPG^~=-h?IFdj*3j+sF>ox|uW(ZiujS*{!KbUrUG3e`%Q~Lw z*BXfafh72=F51hfqkufE*y?qeGc^Pg$ez$T^*71~hetQkknGsY865epcGP9+0 z+OL+6I~P+)NF2f7Ba=cyWw3b-S)>@4tmLc`n1noG?>1zAWT~vpBX0g>v=>(I35u&< zrA(pH+^pT%#$lSCctAPPmFx01@3WuCjn#DIYf&!uL+!n0Ff>2&)L={6B!Ppz!hnT6 zcIsdZpF0czTa1`^DiBP-XQvv_ZU2t*2w9o&RCtxvjJL*7ZKw-9QE;ov7SDS^Kzt>j zb}&B(6JJG~R0X*YQUh*;x)-&^_Q&5lb=jw3 z+2V=`s$RK8i`uVInRVUDuQ5rlg=r-bvY9 zl0EWQd0_}mxIdI#nC8Ukxd8&jO?>3auol$;<+o}ErpBwZZE(xQbuJ8Ot`P;w^VM1} zj1Dblf7nF7-v<*!KA*Z%k3OFh)X4bhD*FTz{h-!~?EoawXEVq_rKi(xGEZ-dP0I2i zD)WW?GL$?i3hq)ZTqs-oHfdrS`h!VJ6A5%HO+)~xH;Ie}k(K$ZJ&nl?hALwP^3lor?0G63~KtxsD*vviimyZZ9gDckeiAOwR z%F$9Yqb8o`INvXMZq2L#{<#z7QTKdLEy|Lfc-~l^=^NFlb5dTKZn9pQ9_x>i!W~PR zmN@3Q#@SZAGtSu`CaX*}u|jyEav@c-8aWP$bw~;gNme4CIPBv{r*6oHCTaWfEmdj4 zcgx~q$l*cK@d%;e6e#;v>yk|C^{!#aT$%QIcsPIs+B=cqDkVuqQEORmDGEQ++f{-U zwRKK?fz#ALSNh}Y09dNxsz^361M#VJDe-Yu8>y4x=orC`B( zb33opUTfX(-p`=R#du$)<6Mxs6jRwb9V736ccl)TZHQ~+`#~QxOzS%vO*-VWaSfQFiv|g5ZJEG=FgS>yc5Dztx zSHEmw`;3u7E5tBl4_h=oAcedbaF#0H^F^Br4~0mr5&<65zD&GCMzJ5qfwvZ-pwiSZ?&VwPGM;UfZEX;Jc&S~~qea4!geQBeH&Rm_! 
zs_$P!`S{Z+rIXF#d@qBwnA`p31y&1VrL*a_*kGOWF`j@$y4_@<;cbc`3lM%DKZi@=4)>N_F4i+F?_h`y}8fx{8ntvLQH; zKX=Wl)}br(z-=ZZYG#MiY>8R4lk+qMb|;Fh?(G{Y z3G}BEo;K7d@+zcl$jho^*EpyyNYqsQ-sBh#K&A31mLc!AckrV>pQs5ye zCe^bWhdx7Jo)?=dy}>dju6;Tvf2Op9Xj207bb1&ZIs!r z!Ypq{l71jC`ABOR z&OCusalpHfPAsnA1GICfaeyist@vH%tmH7w4ukbMqe<@hDux1XGP z*)dj&Iy-)G_M=@4a3HYv9>j>kQ)Q;znWno-kB1=)SWvj#(Orx0OK~c$O?rz3T5@0c zsBGjHwj52~lQ{&-y!U)^Md-0|x!;SEu9qC5ZkzlrH9=z z{{m0O{sn&A;f4Sotiq-IR4{sXppX+J?S;zpSQ9@IWe8lu$)L^=9YB&VK(GS;k`TjAA$^DIO;Vh<97qYlTrDYU3(CJD zC5S@`B$DXH8Ri}BIvS(a$~uTTNLopW5yQkwf98$K7`GS7J{RBRHu^fhe_%g-;vhScsBd zoKU)D9^$w7B^oJcuaIXD+7E`_i@n0T*$~CPnMA2f%3DT0`X_Ii4XbLZVlNS7Q z6MjwHJ?rX4$SqI91NHR+NL!PaWwHAMEvw`sFOOCxA>Z0o&v7INEH||{`W;D^1={IW zVsjetIhj@G9o5Z&v1*FKSWSVYHio;}k)GkFRgpjNX!|z10c!j6{Y4)&W}g{mca4%d zjzkih_!b`Nd2zew-rVW`+t^OZrbvp-trVCm@j`TPTLQeg;{pAo%jSRljr}?q-XM?o zGdO-nuM~_H$sTcH`KRze=>U4Va_Udi9LoUrgW#m3ezC~ZOL{xMk3tLGq}5u8HDea! zH$C!_Nh?)nSy>%|NJE4F0>_>eK5|$Mb0?9WI0wIMB`ehyqIiUAg^`GSgGGR()x5}X zGx>r$icFIl;Q{u%P`e~{I(<&cPso#VTr4F?Sb^})P$`+MQ-qr;m8%Ph$me@3u`;nX znrw1*Ng9tXyT-{efeIw(Fm)>lXGjYP%cv!1`ddf4Z0VIZ{Zd(JQ5@k#PX)gq&3<*$ z)eaA^&;q5*XzJmRPJegJRQj(E9WM8G+_?T!vw@|$8i440doWP~4#qPXu zzB$Uw-AnLAqND>;rIUXLIZ+vmA*^QD>Za5Zz2~C4B?|C~)p`oQb%4Vq$dcRas5!fG8(svi zV069Su~Ekdzp|RH2E`pZl5$oeo!H28)}FKA$e(#Hufp$iFV_R!y02Wc*H6FKp?f_Z z3-26xH6HdtgPgvOFCSjd`Fx|qu6Um{ORHUNoY66uU?XTM78om#QoR%e7H6A;dJXF!htg)mPzSdqvv zfAR3T#VZTr_>JMRP5_xk$Z(mM&5%@Z7f1N2<$pmw`r+-9J9$CScN4B&e!X!FoVpw9 z>*lo;-5iMq*K-{+Sh?6~e}n_?e8qRV+6c~wx~jmjHX}hrW)qqN8@oU7qM1M?iiKmd*K1G(Oi1jWCS0sLz$_dhJ8&d`)AM zYzT@Y{#vm#>TUqKZ>z0$1109uuFmPJ^k!=n1Q)pM zN#8yC2CkPy_padLKj4e%HA$QG!p1`Pd_OeIrB(&^!{>FMFU zWyIl248|SEHGx=?Ph^oDVa)W3Lr&x-oJ}?TA=X0!q*MP6NXv%BG!-z~udY^d+T4B8 z9VupH5?-8(V%F>ySQeG7(~WoC7QlP1O`sB+#F4CTeszS@LFW|{9Y~})^yEV7-~jgG zu3h{yUAKaj#8%wcQsF~g(h_XezIulN*I#*@ z>C;qm`wrB>lct&EK&i78QE+Maw`_heo}Ryl3VJ;!Cz`vI&}rz%m+swumRU43b0aJi z42!T9b5mNOxIoq=^~rYQSc^o5(E+wvk-r0E3kOg^C1W9ynK1xMuK3Ab*TOjXTE}J5 zN}C)qI@q4<)>xAj_n+)80;E`JUwwkAR9mMI!jjTkn|>;I@V)@CfciTuy;B)R zp$ezZRmOmP%EW2+&Kyz~$}J;GgswPdwj8n+8D}O9;YmH(~GjB=uw%4woQLZAiW+}3qREYOgbXEPi93ml`}#cZejETQEwwcXj~-bC+oSDSydYMUy+YFW}>pH}UXW)J0? 
z>92Z8Px~m5=V#a2rc@Wx(;iv*O0gTo<~{I&H|M%JOGND%J=wQT=TU~ZIQE$AmpyJT z5p>-;Ai2JQ#k2X&1;gM#)81~wzB$i!M^lT})fv^>(HT+a%_Q64Wn&>gC&lTUKUqfv zz{&nOR7q6aUBel$Ou?Uk7PJJ3lk!~X!lN&Z^x6sL_+-#CjHZ)Cr%^tVSHkWyC=ns_ zpU|LI(s^rX{Phg}1|K5+1ysK(Ym9uXG(a~elON#Lyku%&7N}JRn+!BTnYo!b8vDn{ z>eG;KSuGP$b~wzQvzv8GbtP&CSNR-$w~w^difz&0Hnf$VXWcosH{L-)*{Jb)7c2{r zEWD1DzNPB=I^xjb^Sr!;Xw5IDNYtrA3;=`SxW^U{q^v_GtymBu{e@wOrFm*2XF??* zr5L2h>=R#uw+M6O4%O_E<9oJ>7mhT(9N&+Y4`cT%itjWGp=?|>51cy4tsh{B_3PQ~MaZ__okI7?Sq#o&hq2u7-C8zxiuV)(Po~I+J zTeQ!(WFLSXWpo9Hvut@xPPaV8XeBSZQz5%1Wyj1xyhRZi@@ax^iKxh;!iz@PXDu5o`{B5T1`W{eXPTblu>q0X#pGqsqD29;H!oe- z2x%K}YSeTdvac)^GDc6?cgV^a2?l?o!Al~S+u1E?wCiY#_@2@sa#9lPV}RXIXXsa^ z<$=Ql33SO_%#rEP5GPJk@xrDOm4AXtwgy|oT9bC{!yrB@LXDU6mG>^(;H9F5+B|ynP2M?;8zaW@JFihCk7D14xb$U z%JvHGB)S@26nuGZP35dC%gjX0YaJ`^5l`To@D!#m4qV5{30)4p=2OUN^6af~D58>= z%g>uv@M}yi`+y&`A8al&J7`rcS`=SbPBimpb&Iq^r{-W>n`QNpk5eH$66gup)@v=xJuUk255#|o51S;U9owh(h~m{4f(xEy0^rW<1D#pEYH>l(eQG&S~7J}eH`@XUNzJW5BB>i6k_ zmD^_w2Sn+KMx1Rejp8ZYIdZFb7qB@0<{F%Lx$%-=S8K8MrhgL>+A#A8m_Q;m?`V`0 zaqnr#UCN}NKX~xq;q~Pn4^O~hW{sW@AgZyr_fPN4X*b&Ui_o}{9W?%0rP0k;#%Oy# zoxZ@us!024pB4U8!^z<4ahi!b+f3A;5?ki(@wkoX$L*v^N{y?%c2cFYXUv+wW3qRG>st&kUnKe=(C(#CZS?au<&&SuvGF&78E#)oVzz&>P>FIOl#8Aw02R{aycFdSYb-_25?tei6b!x}MA1 zoC|@v#Rq+5d2=?bGFxxNFBM74it{#?>yikd%|d;ymZgsf6?uP9B&h=!({$t?Xn3F_ zPYJf>_X-Qs?r)9WawA*I$8bwDzNUwFlFgN^-u=r!cs-l1y2$V^X;}O2@AC(`?vHnZ zidOJ|of$BR1JLU>-KUehu+*qO=RuqgIQ0bXnR4N(@!$GbusQ9=B}}p!0lXzs(v29f7_mUzvPORpY1GF%1i$ zG8K13@r-BaC~cecwa7z5#*oGxGWbN6wcuv%6q31ec*oRAAeFqu0`L42$jaSCjz)f4 zR_Y`I%)30`>Oa(sM?hrSqXic|>~wpd@~EU5)W00agLB(?_;9h^fxu>;l@e5Fo~YG# zYo^vSiLtBz2~m@mRiSY3C`RG<_tg_CX>KKhNA~jGy1!#xzI|w|#mst>-e~Zc3te;? zr8SS9W~TB)h<)H$4RE9ls{W4i;rj|i+IF_%?D3S@J%d;JdAD&f1BhLSgHR?G66hUc#4W} z-h9#8W_mfz;b@-0xR*fY5{7=>RGQlGphxfm!xLt3!ki?Jlgzg1-7|ALLQx;FTFw^F zE-NH;K;=T!I9n?C!X2i_Aejewjxulwj#qLDvWF^-P_62tP>+S4$hwPYh`7IAQUTpRQ zB>XJUw5^B8F_0Q?Y%lH=V>ShCjG&Q_v+r71n08Dr9%0hNv`n|*Ao4wFc8x(ld5L+M z>E-q{yjw9S*Q(Oeceal%CSB>ZGqN(;hcQ>QQq=dh7rN_w%N_2C@lCvW{1CQfTHjVi zu$PiXrBbkSEF-kxr#U5WneyQ+x`_>>PB*~-O1NDL0bfW!JRRiP8j=ra>B zv-m1=Io5gL66gk3MJ~xR-@Ud39Y{?u9o+u~QG$Pb}|@ygOptRuw!}k~~)TJdoDD;{H0> zMZ0t1KrgsL+|aXU@m)+6xf0MjYlH=Dh=WSr(t-%2xDXFn;GT59FwJb+rsaNp2t91h z$YZEIYMaeX<)u;8_fVOkUnvgIY7VZM073en9`8UU6_DMW9S67ot^9IrC`a#}X_jGY zq5i(=E3q_z?QWK2UwgmGcILXa%;s;Y?)Dk7W^3YwUC_sown5$25=8P5hDFlYgfe#( zmRvbNIDsL%&$3Tb5OS;7Xi$HYEG(Adu{;b_0sn6PebB0UhZr8LP`CiSSyIY~XzSP3iYL%sL z0v61!iEVm}-Qo*XDl%i0rNKEIAk^TW1{+j(hw`kFps>}QGD7yCp3{ zIi?L%8CDPFjJg!YO?PYH`aseZhrl&%cx2gbdFa|6D}1JpI67REM6VLTGqv18(4yI+ zicPhgVY`^%u#P-d($a+74SzSb;H9MxH)X8dBm>JYnq4OE8Wn z0DqExy9;tq8LB{7PKzkrb_8yj`j&bu(XycUTj4mm{?rd;17akr$94)NNQCQeCV!9$ zlB$F{iG&`EiwFxQkQ7vj_SJ+i@TRxWMhWQw!l40BK)YMmTL!sV^B0lJ_|kI(AW9#; z2phWDtK+vh<~K+w95GOI{rom{X>Wqf0ymz6D3vEWbs&1V->;Y;~erym(j<>B`YEmUdiIM@8~r1b>zyT zNt4)Elbag_N)Z}E5~@iTt*B`1bCU(N_MLt`34?i_DQTWQeYay-G%?|)m5{+98} zaDw0%03HAb;sZ%Y1WF~ELdFb=>I>9~hb%3}d?zk(=+Zmk@w0>GgsfVUz47wZxiM;9 zLFzhjpHYLrBz<85s`+ zkJCb@knH4Pg0;sKPZ8(|^ACb0Hp15!(ck^grmCUuj0wIAusufDW*r_Xu>G&FMgFaj zUj-Opll)m=13juGH`1065Y@dEfz&PcAii0AHs7It56W&+>^LPqOU$;gA;hM!(Pv}A zbSz#a)T2TJBD|RW*+^SR#RqQS*O9TZ*Tx%f1^j~H)U%R<;klg%v)^5T0f?x>^o#{S z(Da=#GXt&!A%9k$!_kW|_7y4kF<9c)Geiwug6_AV2dzU7Fh?B~(Lpd#kxUiH=u=ju zwWG8ws0Bys*`)R41H?D5dwOL-cCsV8+=1{$(8E^8V+)Jmjtc9n#0uj^b`s{=AZIgm zl`2(){u0BK+F+qbOEnvjdLE%|pA$5BS9mZ^c1gaH%U*D}*!xy`_h$43EQmopReZNr z@PKL9_NVMzK9KdOnxS`D?+Q6HdhOyXH9Kn+S-XzQ^?r-eP@Q_Sur;)%@^(A^qkFRT zdj1uA6&-uSb5m1NpGc=WpD4G(<2ImT`16a;-3&P&CQrfHPS>XIj7yCMqtg^F&E0vL zb2&Ydw_BTK8oPqbOXZB@*fzOsvi3!r_)-r>bak8ZfE$GvuozHAKwgXcH6C7c^(FPk 
zb)(E}3%0Zsn=1Cs9%OKJ9Y2kEKobS6G#YFDmZswAloomhgK=Y;$HjZXrUFOfd{5aZ z4(h3t7J04AGg@e-RT_+M4KVPsHOdkf=RG68DQGKLW-v`4FFzt<9lQf9yaNKf16*A} zuGT%L^HFo3g{%~=w8a^hW9C|^RbrKzEwJwZvXnkm*QKSSB9 z)_U)PMe0l=iDO}sGn6a+%Q$V$!MUG_Ips*kQTd_QoA$#SSF2rWz1C%fY^c+ib+?E= z=w|J&6Ijabd}jg;$EU2w8?WxDa;nE| z$FHQ|HN!fM=>m#=-uIw3XW`tC^V5jR6@58J~^B@l3&>$9%z!Ou(Mqh=3R~7n&5uo!4 zzARU!fKD+B@)=72E@<~ajD0HLrm8xi-vWc@1Q-_5$eWF(h=LH0D}yShxLMT9s9G1a zKX?%$)0T4Jc`R2i*E>C~aQ}?FwbID`EquC7W!ty1^89LE z{rmXGvR2S^wZ!z6VGjy@VkC1(M7224A)Zlwzp^k|Apx1DX{1rIyI8}!;aG%CREFn^ zXE61R4th1CrY`0s7Jim_x?x8C_|-I6tH9Q62f7R0gXp8}#O=A>jPyM6fr8|yde4)R z<u5$pwlbbo#vuC|7qgMCYV&&ojHK&-}A6X>Jsdpl57qP6Om=Tl4aeOWqq6J!n<;7+6A06yblpsBSzCt^vZ}W8SZUt;2K$ZoYU-V z%Rp4&Qh7Qk7kQ{x+J(9(30tv73t5FL#7~ zSFmo*m^sDTx@aYo1*x?rDBRGlR|uKuP-t^W+Rkaek$7=Jq-2dd@KVv3O^KRb*VXsE z5TInhMdws^ks!o}n#EbOl&xupO)h$-V(z7nPElLc+?eR+41Jo4)ufUST0$}xSlRlL zqu@Q9-2L%*=3GpcDQA<(mv=(+ma4o;XXcIk{yY0EY7;i;Y{37hF!-8>Wa;i05?TF{nr|T7T#^gB^ z=$jTea`zcppD_ZV7-iD`uH^3c?%r+E?ZRzSdQtj8+P%cM#JQxJVm;|5Z#3i{s`Ds1 z(UK-HF%cSV82yOl@Jq61*2A>FHuC1WLC9z6UGF<4pnDuWu?J`{aE#?2Q+|c<>da#V zn88RghfpY@pLs7QYjeyb<eQBqBYV?=sDYo+X>pK-HOXQfj9;@ z2I7IdW$j%FAbFv&_d$P-x%4W-Ti47TltJ-EKn_5e?B@S`Tu*M)W*EMv9&LMgtyFI4S)WWN41VVkL`vdHEOPGugTl1Ic_@+A@a7q{@Hf0vM=z;>UjHXg%>gO5w2^liJyT_H3B1|9|V zc-m!IG$B7O;VhfYG%y~aVouCwl74Q;@^14xxtR9)?Bz3wC(RwHs7=%AG+W&CcwZ|E zJhy-+?I^V*N?_G7b9rBctmwj0!p)aqj-4Amrdr)i0Fw*1ws(I!AUS_b;9RZ-{m!{h z!a7Oc;-zx=pw8>d;^$L7X8Kgj)i-+G_*B#-+$)7M3Ue>@IZl1owtA?!dQ8uyp!3ZX zJAXOdbbG&!$qlCZ+;Fw@H~PhAb7%N>4pUt%He{@8ZbjaZ9Qn@~0b$3>P|1SD zGF$crO8GA$$0e^Nuc@g_SP=^wrJ{GA#3`5bQ>^rFW+vOEDOT61z(916l}DT`yw2K(Lxszj|Z`6(hiVyXL zRgbQ<6{+3MM}a~pQi;RvPj&jOPZk}u=6zWNsLB z(-hL33U}$=RR8*PX|sX!>NT1ee3Vl6T`dOYmPR!@D;i7rlV8aOzJVeq$Pl6#6Q{`l zkqY7(`{WZp?dTtM5B4Q5h2#~&@~IH$C0$e{l}k2qP)+ly6>SM41JP|mbF4U~i*8S? z*E=W!Ksw@1wTC+#27*Z5RIs5r(TXXEW3LD*>&Z?emkHwf(Z-7%%onXNuW+NvkuXU! 
zO-ao)ziKutH4S-HccImiShi&rJC{5i@8)?+m`PLMYvO<)FlX$KfpSi8C+u9(9|0Hj~WCZU{Z_PR5jNFfvKU+Gxskl!1+)sn7Wa0>_h!kp(q^%~ z9ve2wZoxv8fQ#uw2RLhB+(RZJ$DeMsbD$Rdk!N~Pk`!1hr)&fMa^KCOetTw z)e#HAh4RwPMlPqmjdHn6;8-r<3?tH7yZp6~*)#69Bn_xTL>fGqksItb!MFUy` z^a7w;UAdxtk-eX>pR=z`N!vo>taH*n)49>I)860Q@7;KU>t1^F>&W~>U>5zpH0Q>u zdICO2ns~FADSp;EJjh*)Z!n zbWO2iW@ct)rr3^|nVFfH?Zl3mnVFfHnVFfHc{+RlGjryiduQEREtN_oeNsuPUR~AF z`|utSRhq1uc7Yz{e(I;_7|bi=Ez&I#s@N2>8j`A*)ied+C6g>j?;JD9i3qf-ym}&Z z%ienm10T%a5PSqTBj@-z1n5%;C?gp&jA|7=wkh4St1hl4lv2&*_(wb$ns{&BSq+oA*q=eAsb2Mt%bJ?bbuowLi&ML%F z#~a1Yc^e(tvV^D1-tdr&Jq4qxGh^A1jBVS$?nZn#h|S2JX-zznj2(qobuFK{*LJJT zf2{}KF&j0w3iQ?WtIXNC52q<6J!F!PbA!*RuwJDm%YAk-YcXS5CaEPsZ60zvGRBA} z%&=0;iN^=OD{E5v2#*AGzhZ1QD=u6kjh5q=@IJ3YtFldlU5EWa&EiZ&)@+u3A-^Q% z`mNC9?pMlo@#`cCTaB*i=BK<5;)E{ts^wRTcEucuQElxPAy6InvbVz3z23a4ps7nn zi;SR$_gYqk%6!?et>xMD&(DC$eE=M>yFdZ0RRtqraJRt>$4!2W!cP6?L6!h?_JyHb zC&q2yy^a|_)HbLemEQ~yVTTPu_YXhBdKhc?Y0n+ZyJJVSPJ{^NgY1^xV=3x(+qi3{ zR46l-^=bX<1H zrd)yv;2*{GK}YMkQS4csc+Pw1CZE8DCBddT;ML%pj3~EQe*V*guaa^?JDoffRbV_|9|(l5K>a{a!^p+4Zqv-sN!`ywW_#dn zn9@atHrUAbx>s{if@gc(k3d&mXaq7;%=p4KHcL&B)wQ5nqX{|S!=}JNz#%bxW^lfo zy#iGey+gn9fYQZI?)kb!-i=r=MloJFh0KIB%`o*40)tajKKbesxPv{q8_xQr86rmn zt}20wmp%^Tb*GQv~$j1J2+w8h4N9!OugT zbGP$&$GHk;ZS$=aFEAbAdX^8P`&kO2`4NeS{5t`N5wH{ZGlm{_H;rE=K!ONrd*l1& z7^Ba(%$)X-XHYj`H%znd-iAD75j#F?S*xilh5X^20<0O*dAp zg%nd5oO@U5f3CU;_wsob68l#7x_%^0^HsBcZ#Uj|+7Qe-@mK~kaKzo$^ovqe$Sn9^ zc91|HwfsX?O?Ka2x}nxq%?iR9DyMbykRWO^olBb)>C(-7Hc%!|Xe9#rEMf-a3eg(I z|NGTGxVUcm+o>MUbYqOy?~pVNIn40`1Pa^B=q+JE59lcr4ZZXzP~I(}F5uM3qb;JA zcTF_}yeTg_g7>g%j+U4m{!wlao0-a)bfGuPFYAvl@Zq)IJpyL@DKEnIgpqEy%X~|r zS0q>9Us#;;SHNaK6ifxzyoKt`>R7=jh&hTtj_#-rC`$7gw*D_S0qLO)Ko5CAzS&33 znllscnklY!pcVwSIjphh(j;jDg*|Pnxr;Hks#QpM+R zw(_j?>`ycu(-n$|)qSQ6Xq&azp&jBH5u9DAc*BD5=4{nHdt|?$)dLSA?l1Ow1bO(~ z#FJUkAwbD>2sjhDvQMX+=4#GJT-xf*uz2oziA+r^nt<6 zs3oIwjaHxLA-UMwUvBKf`WCY55sP<}^SbH5Ho>i$-=gELH#M_uP|LUIs$FbD-W;7; zBDOc!g?T%gQdgPgMYGRD6k;^3^jv@pyt7@P+OytptoP`I&-|k3**0LE)FQZmB^C1+ zv9rs-%f0g8uvzq+!xHx}in6iIcqxT_L)g_2M&)ol{B-?IuJhz6mrxfgk)3zHz$_mi zoam1Pke^j&I=q-y1;*O3QSBlaKFO+S%|8%X;uquig=E0jU7#z`R@mxtJSQZ%|Ed0jc!WWrcdtKR{DN z^A%%H?J$S2jSrW@eSz!V7UYUI*OrO|t?$D;*L9M_%|(ZftvRJclzuPy*5G zhze>&^2{3S@7(6_VPBrfq4id0KIZVi8Gk-VdcV2?`D(Au$XoM^8uAf5eL!44xwyw4 zw%R=FK3(`iGjxpspJI4-iRwgzM&ov6yl?}Qq1S%yR089H`8kkNw845{-GM!H89d+~ z!n=KwPjxN7QoH1+>QW-}oC$i=yB5UeU;pi)xLuaRD=<`hf;yiT@JyuTc{61E8KdiG zU;JxL=nae~`D?G11MEPNz6et)P(rJFUF#qJ`*llO}J+;>bHv4T;&mT z{7>Uw{J`ZIs=zwnZYN)7uKHl@P}xAgd?C2=vCfB^y7^iHGGA=IsXJ7HZR3^FfmkG;&b{!GzNk|WF z2im;-DUDLhdrQoJ|Kg}Bye`r{rZ|=e$Kf(N4+N{KdnZ=mGTgpn5$K*S^34pzdc#FZ= zkWum>VQp&BCDa))00Xv|B|kr9`Z4<;>bVjooAU;Y;#ZwH6~M+H{XUK__{ibc)djj` z2Jwwz4OrANniHgT47l}@GrDa{&moA-4`bAhK}slRw68~tzQc}RRKXsN+1f?Y5{*Os87zj zq|2qR=B1hI`_ciN%WwJ1@BFQsF;6g_WI;mK)A}19O#lQ{}=am6U3Z$JXJg|Y&CDLR4gz+Ti5wsX2q zvRh!r!TjT^kqi2_^f!)Todu7ab@2WsQgAl;A7Z3H9Y9(@ZNP^Zj?&&g-JZC?>ggUF zU+hz^p7}wXs)2nOs|4DOnCWJ>3yI*>|lIG|69hN5{zX;KD@iaU|GiGv3iSBGNwg|uI8itzuE3FOSr3-`_}vuldi7Ha-v_duoMPA)GvN+37+xGw2F(g-^4(rIl78FokPc)F~DXI z?!!1oVmS{mYFe~)8E#Av1iAt`hMs@myy+Q+Z3T$crS9G{_5A4d?9V`aNPx!qyA>#o zYHH3n^Aez*IJ>Lz4D)MRCcpQvA}>is#hh9f6v;*8gLG-gMO337*gN+4m;>iJyivUieQURIskojaxpRP9j# zntpPil*$pn>tpoJs#Y1QW5GG4n(+Z2WX2PQZvr(-dO_HxX@TC{fRHAexQ?utMt6fR z>V7=~vYS%oLOH%_3Bm`pJQq4%L)_L$O>zP&E^^N$K7PoF1bSe(nnTR)oRYoYJ6iqF z;`gcKTS3g`d-=-7=YLvv6Z3n0?mF4(1|$Ui3io-($++VJnLqkS^MSeiF2~<&$S=U> zkDulP=`}}A@KsiM#UJ|aFzo*}8wv#RMZ~O3|D|mEWC?Wb%bW(>tvY!>18k%d^4k4X z)N$uuncI5zYOd|UPfaVWHd+TArk+_oQMwy>Jan6( zC4mDRn||eIc}Kb^hlslXatkS1;~-s=93gSuHNP6KaA%ogU4uQL&X* 
z)2y`AHh3;wcCg^w$dc}}nf$QVL6(rOdigP3a zEl415iFCzq9;r3$$f)lw6Xp}A&1@^4^rjs(##nHVl1pq#by()R^C-}EC4QQ@tg99| zqxfaYyyUUP4&swkTH)Fbd(#CrRs0mQMIV`TMYcP6QNO$7e-+R`l`%IiZSkhn#MLGB zRyoTb`ZT@I$}wYGc}N?K6AJdaR+31+pp8D<+2C2pe@92zwU2|<)pxGQI*j!t(1Pm_ z(E-nWMcbYSzr74}z3g%?aH8-@*+fZcEu1K;b~~iCw7iH-0r6Vp6~jL!a<2A}xpM2( zg|120WK_LGe}q)ql_izMB5mn=oRra4cepOP=78`8%1XXO&Wrvh>s9ha>SxorTH_Y_ zu(QwOM9cRMKCLZbT#l>)m^WmY5C~q#y^b9(8tyOX8>>G1Zc~{4_hi9LTh(NE_4k=w zM$sjPMDR*{K&%>5v#)l+k-irCGZ}UJ5X$Ul*LMog+g;d4WAT%n&g+oP_!3Hbn3GWhe04Y(>(p{aR~d;b@&TGR#3uoJ>Jeh~IKCCo3?`M9ZUcqjBnU4JJ-t1HoU^>9y= zbh@E6h+;@Kd*sQYdzleK6g-RBPidVp@v82ZY;p#LS$+9q;&N;F*{gp03;@CH4yv4f z&MG6G5BKSfg}x#9&^^ST)+ChMz{Tup#%hUaw4YfTRCQau%OV1QIg$%bPUDbx`3ed2 z1LWMd6f+ z+4jY|srZ3$Ios_?YgE>`{S0zc*###m$3sqotP9@R4BO*KN4-`Sdx>@otv~0Gc^93h zl2y{ZLTJ9WE0O2%%tFiJ*6V>EB(v^azU!d~cl_(|H4&p4?m^^CB#`cq&l+Y%8zuoBDtr}s$ zIvR2XTS0(AKg;@d+DukcsdYMGXIy+v2Q_yYKE_8h}I(`yKw zvYnfH9K|~R{<&J)SE+`$TH1a0pH^d6dBr!1TwQ_hyvdR)l8M$Wu6{g5W0EVui;SKX zhII2yi#F3qFmc|uk}Uq-o}NS6Kj!jn@3|e?neb1n0`FH|9=+PUWZXn8taY(D=(Znl zd>S89q##Rhr zbD5Ti@0&lRmxUCU$;BU^=VYde=>cUFJpMN4u%qwz-UDO8OQ6|8uJtboun_Dd$v+O ztRuv7%a@y;^oG$8GVEI))jrR6oAzt4(d(NFRhUl|20!jVO|qE+=*+&Vo{hmg^8{WmIa+(! zd@@tg@t8aLI@6h~xg@+y>hRlE157XkTJVqlGTXVLyzn=e-`V{!mLa%$Ah!vHhD(5} z@y@_0ID;Y|8xhceLkkiwgS!An&F>{Y=pi42;Sa}G z^D*6Yv+R2t8I+-1oY_#sv{5hE^s(>Tylsgjw7bN_Ca6+3s7qL%I-z>I2gmM_QKl#> zAsMEmsyJ}OR+#RQ)y3STAz18gQEJZ(XK%cI4+G8dEKjb^T+qRwSy?D)`#Bd}DUXbK zsIM)VZ=Co;A(*n2G92cH{H9X^@u~@af@TF(oxn%g( z8yN>;qnmhg`worIHjJH8`#mg7kkbsTO@RCuT<3;52y+*Zrs{ZXpaZdzl{JRQ;@~jjqNs|CxWYS;y$gti5`rui>uEG!;^8!l zw^goLY0SJMQr!}k_|4j*!0<9|3XzNg&iDO&l1i^wkZiPaAa-E*Mj691mVxdMJ@#Lq z^GrdQ1O%(yqoSCQ|cm$6}*?Y$&VUF zyGLv9nh{N94gGKh@LmRZZwDaotsSGo}oXu*`>z55^ACu867Y&-sYt|Z;t!Nvy!h`FRb;A$LYpT(> z^Ctx4`|GbAk1i%*j?8N4XdlbbEEfy}X;tzEpc>Q!sA0)Pe0ERG6_mth@OF~mZc91M zyk|xDBv|mB`S7iU$b0v?<305$z70|x$h*@FHOkVcz9mgjspF^6M{ai$yTOPZ({A^I zas^m8fsu0tAmC8{idLvOG=rag?lHxnCr@ymmm>Y zCvTLe!|VS@TuGkasOZtH)l` z59pu#-X_Tt#vT>JWte>3V%xkU|HkHf#c%Y^6zyaVxrbu#h{quF+JrF_7RT4$jv!@$ zfD2sWG0orS)+kxJ85xuTxyObE@yGXC6?^H;Sg1}yZ8NOHuC<1vc|1AlPWe?8b4Mlp z&QJbQ*JfNrL-m>w;t%Dp&`)cbA136_2O`k7`;b-C(iWDzaqZmg51YZ8)E%Uo;+tMo zI};C)H*inJFHBy--bpVC-h5d@N4JppFCpX;%sU8Qy?b_1?>SI#Tih^68l69YH1 za?s+>Qod!gPwqRl7W*II6HS)Ctlcw1)NAZ?RdW>TLYTLpwa1TTm$0!v2&S3_xX zA0Qq$SH+aWId!rWF(Vd98VVGmKvqMig(F_%PG6=O>+rsLxl7m<%~JQ?Oj8>Hcf_Mm z#d?TPQYuUp)NFn(_Lh!*1P3da8u@7y$eaxe_-0fR9iyO%XPjd=i)U1xW1RIL{Y34; zr*iix$&DcEQo6CsD;S;SOE8c$maG~@b0^h4o4cXTsAq1JS;agK>O9wV6rsp(ND@(H zS7Jq1s-P&g0H@y7wAG12j;eqHd#s9d5F)6AP%YQ!2ss!%bgR#9Vd`wnW~%Ng*N@%@ zm!YFBUYCVcmOOo84@#p(Hr}Z%4M<3sc|2V`DLB%$l)P#vT!M3`)!Qr$)As8u0 z*P2+OIXC z1-@)}43RUO*ttVBX3xY?OCqemZRc)FPyA)AkI}wZ8~VuHdLCYT(aSm*89V&N2~a2tD*}!co$c(bjIIB|X#RoV{DWWkOBZl3 zw{x;}fMWW~%mF|V0JfEE#l(L|>)ZWr5&?iipbEvx$_Aj|00;yO0LRYy7dgNH#l+73 zkJEOPxY2}SW@i2;_HPPpD1h!i6U#r@nOLEi|Hd&h0ur%u{DlVm>oWkr5p4fB zc7WpiU-$Qz0g4@9{_oM>V*Y0QTL%t+@xS4JQ~W0<0Pz5*0>H#y@Bttzz&zmS-wXhs z;os2TxWA_U?*9n`O#ge92@v{gSR0Ci5sKsA+H$b|=TiP=|6kbv;eToZ*lR;^GDER4 z|5FY#`+v&R21Elm75@P-{L`%ez0?0o!~fF}GWym4g3$jH2?5}hm^wlMh(muzg^-Y~ z8=weUHZ~S0T4q3lGBE(Ez{tSB3Mi0X@GmC93D8ym($K%m2q!37CP1sui#zFCnHvh) zm|7VFDo!t~Z}%U55j&u(=oJB@NCk=maN=(n0D*tN$or>J{ts}&e+9ezUm*bhvMB(U z{~xrITe^1ah)q8MOvn|_w;=#QSv`52P^h%9G+r2;^#YV+TPg-C;Jf4STGDdHNkIHT z-A9^k^7-;+)1+u-N07QJ!i!l3WcBz_wm-aGZREi>c9kL9-lj~CsalfDMbh5+R(lia zzPhD&Ed#ZeiHsD8X5$q{2843%o3tmxq-Ji;FpV} zyn`r$Rw>YbfR{x7RQ6Xo{2}mzOT5)VRtX9B7m0#0;f14!sGaG{(vFHKE!FI0iQk1GpT#dAQrfG5ROp5JMf-|#OD=T zDz8o;NySpC|NaHS^k11d|F>@V=N|i~1OHbyTy#LXDTyqy`Yu!F-8g(W%aNMxxqTrq~hp&4Tdv 
zE!eM%wVy3fKQq{hUf#oL@Gs~Dp0SDMbSJ+SP#t}AUsf=GM1K6}{L@lrI?m;EnTC#x z>J6SL=oquDu{%jdH7UqrdMOf4hhK9>3BQ~78+`J0;r-5}x&fPxG_?ixrUiYXxt-HE zDB!u(DE^6Em#OVV*Ko^~7>`K8Yok7m35-oXBvuUb$LkIT(cOI?Pn6;bGEpLC@XegW z+A)2(-%Ck^wW%-29dY^#TS?N)WIlv5-`jAJ4I-g9ws8>0*iui!nv{!zL zD6HeAdRT9$S~z1C@@;$3;)>m(y?f6ghInZrmIgr6{Jh-)4j9~w9p z1{UVMRz1LvNQo=?@lE7Ktq8aj-bWiOg`U~Yt<;C>^t-)e!5vr`$DN%2Eb!_2gejOt z9v^(HEs{1a)a0sK#Gl6}<$j&RN5TwsF_~?+XX*Iyl}n#r@T!0%9XEZ8uSRxP$Mf@D zno?PVh&E;`+Xus%*rP3ZQ4xc=HtlgbE_y=u1fgRpR-Frb%%qm}9D0OqBKC3QX*6oY zS0(wDVIMbJYuF5t_Xqb$)IHUOA=$j{=;`t|E55eyjVU;B{S)RB%Yj_2GOr@1)LWuZ zqOFs?`>J6{b0=_ZaY@3pEd^{|QQut9`)}UjLn(2vCXOqCG~;kJQzT{av=K?E1r4*p zv&6R??h)_xOmdhxTE*?9?bfY5^%5C5qB2&0*74j)Zc{&lKCzt0@g9!kVM#n#>vst~ z!ZrmiY-hHPwJZ6HeJRogCq<~CSD;)VJcB;`I|liBD8%1K5&7PzqvkE=&uDzl&(~g_zKa}+_QtRsSAI)%A|<&2vi95L58-sfMEXGX z@cI?0k2b4@HYV0Ztp;0WY+75l%xT5e7D8*n>H(W^4;Sy2Gl|Bl0GV?XvB03-4G^xpY+x_`W(wWyj|8u_*3rl zCtra29ff!Nz-Yqg#wghV;{od&W?}fXY2XQ^V(1C$9ON8ioskT)X|m)vdNO<3mh!|Y zL4!<({72nuK|zXfb?0$QR_n*l6170@d|HLt+IU<+Au>e@U}AU!JTiS{FYyEiusfuw zyO|~wN@elgbQ5XFKwrc)VbB$Pvl8`KebfZu`O(c7ILaxBq8Z@DeuV_-+kwKQ0-9AQ z%X*Df>I;41V@C627Ufo1vPZr(sVm_v`XcV-U@JmvW3gof7gTCf$&;yA&CLGFJpEU2 zaeWoG5M0uGze*L8nU?!F{_=p$oG4;=3b+IIYT1i=H0IWb&mO!&+Z6FHm__Lj)>L3> zN96ktk6o22#4{_(MMEl@;i~bsV@8t@RC;!J)yn+5d)wFLdTcCKki`4GVV(ov1KtiT z`a!5Zs=Ns`SH}~l^5mx+Qdy!&!y~e4eatclz*y@$nP{SUqmjSnbqz7CL0 zTXn>@#0EW)`anaBj$iyuvzk_@i#sR$yp#BEzsGaoUW44#>8SfWI1F0n>sRXs?vqcR z3RNdQ2rtf#bsRqasP@TS}PUfnc=_4%E;3oTLC;^XY(0(D_ zPU_2G_0+ChbuR^>V#%f`>3c+pE1_T(!6zP$$Qne=#B2s^P-n;(LQjW-{mF|{Sa!q0 zNcqLr-eW{Mu7lputeILtpF|)29;TBM%1EB-P@*B7uPnVt6`I9nbsQbma{wv~I{KQF zcu!3S^+-!IJT*1}X=M81?ajnIJEzcr=PKb4OFffdSbyJ-1qs{_bkfZGLdC}I~6pH_q=w;vJd^meSFl)7N^TmdFh2> z8?kh2BfWhqO?$+P>H$KGP`PDfv&|l!zLf7jGv$R{ej3?lEAG?_JQUkmWT!q6&>Tgk zGE!@?=lCc&d#I=g+sh|M@X#|bOZg-21Z7hQ5sNVE-wjt&rQS^K<>e@U;)>u>+giBJ zju&QotLQEMaXjt2{R6%!YJhN|63VNJ_n7%7*I{Y+^$PWj5A)enJ&uDwbJ3&q0yL4bh2AlS%^6%sG-uCDGCpxR)`f*f;Q z3;8h9oMlqj6sv+{^<7e$r-(Q)jw$$vCj7vuhW%UiNED*$s$y?qdu11-jO%Q1i*WDw z2l7f?-b+IqNCR@=!kbuy6_b%B_e;DQi?5w(CL$DhM0wI2DKSaG7!kUJyNs$nUP)7E zg4R1LbU`L+Vop4f|BGesWZ}$=X&^22nV9>kvLt(<%Ll~~(mqMr%-D7GcF9dnQNQd- zhIhfAPbznXi+yuy)m531GJjx=mRhpyD)p6m2hp}pE0gXknDTW}>sYo5E4&4iv}qUb zr_y`7329i2el0ezQL33=xWJ7Dy>Jl~$t)E%m%FXYc| zYB}?ML!|KJ1(y;9NuoHUn2|y%M9~toI!dAuPjx)?9T?9MY$rVYMNX^qyX9(T?ofHs z-brM)d-@59af$JPz@8qCog{7qMrqnYl`{t@%sWUmDD4)Ew8L9ArjITic@tRZwM~-jIj$e|KF%MhudGj+i?+^Pi9Tq)U9&z4UHex2FL4ra z`S~v9%9{3H;_XR8SS=`%PHc1>cv^OcF@_OhK|eTxjNZ)EC;< zT8Sx$Fkn~-5GEd%5@h>Qmcb55m%@?@u(~e42V$!cvyvGnJj5u(PLyJEm#zHyULvS%`qw;bOb>>FxS${et&LBPO>W8wY6oy8$F zCLEcizy}M?XesfqDe5;peDos08S|eQaSSv9xl$NFT<~*yNr+ib=9g55H2* za1Ki;5FdlE&>1%6yT}Va{RK){vS|uGTx+{(pMOg>j<;K`f1~H^(0?l+p2Fp=9;jc- zcxbv9TRS2xbovx=C!3wOHBsKs3WZ(eQTY;h>AeR>t@&6rZ$nyRc17dS%`XRhiW-lly^1NIEGmUxdp59!u=!&!P!zk`~jZ>6~l8E;T`Q3Y~bUK~Q- zRZ7c9MN5t5CA@ZeoC7u%d<^JbSW$)eO!=7d<1{-v8y=a2RN@fw*U+rUUrx_3YEZ{Lv~na^NIbPOEn5Lj6V}fx(V<p1I(b!}ov zo6W&|tcTL&lBAbOG-uf>Wg?rv-+ioyvl+Hn1i{|$C!X-7)X)f4)CbH@DSL9<=;P6d zCEqVmzPp=!eYw?4*Du;HH_jp#Eh1@nb~fa*AYZs5u(k`VHr`7g5)4ezGvPu9i(&7$bsQOXDk9t+`_Weyd_dxL$|u~ZOY>jw>3maL$|TdN{%ro!=24VEVGFY?YG zSc^FJ&vfONg+}N7U8p+KU2#mHxT?1E zLvn0x4^6x6jWic8>I=y&*_-BIDypi~_Xu&<2W+l~FtN0(jJtuc$ZtomQP#}{YM0M8o-hxf)OvHTq@9i$z^+Tx~2F4^<5 z^N~uN3G})3HYRfCg7Vnkq22NwmKn$s5nu!qREWT@(8j>2H!V}Yo%XFD}vBXaZBqxfo-PBoQ*xRI(jCAy55dHdjQUp^3 zlu+|+Rf}~ZR)5r~C_)-PgOTj1YSNcP1EOFtKQn=y{NcSF`_@LXgV)+{welW)R!6ES zZ6Xq-iE|;tmQKq2XNOd)AV6<${k?tVQ|XwtC_Y7Tk-ug*Chyx@utWVh2#*4@@SUT~ z{!dQU6swj1nUgv>M=iLRNpTEtC% z2~0y!sJcSWWquy3cGFhVPO-;)xMWpS)UN=^Y6&zPbDSe99=tz;aN+8HAbt7zib$42 
z1GoWtN;#U6Oe+B)#5K~C5NmQCCT0$kE> z3_qAkVX(5?ww8^j@mn3w*S6oLuSAccmB&-?)I1G`PUdClevtt+dP&d=tsDI0@Ad z1SZke0DrbOUYl>S2!DpP6R_7e$urA&4q7wR%ZIZg+&3(`Xc_5#=-%jF$$81S>5h3g z7*byDKM14^I?8j6sK=k%h)N`9@}skCmGR>DuxpH|MwNa^MrG_NwP|_t<-7XW6?dEe zf#KC}(-PO8lgJaAlg-aLTf8^h9)Fsk_9j9rNfg1~F6bOix6D}WQC&5!?8}{P*sNUwTL#%Wr%Z-q?k4Net>{;`>T&CR^L#Gsj7W^~&D2OKHM|Tc2oFl) zy(t~SrbGA=0tP(=3es3Th1$#LIw^I+o?4&rT9{dR`@x%kC5_yNf=$rYN_EeSCn4(d~An?phJP`S)3olcy!VlbxohLOBn5ZuuOJPxGQ+DR!GZkJ*x^z zq3C&7fPb`qAL!JR9gHTYAEkykV#hXfwTfNQgEcCigi}UmK^_p*W zpK1XYsS)7&MlA}GZ7N%gh98_!9Nh4VPA@&Z7@6)?@B}lQGj3%Ict~@hhJVD;;i);4 zptt5A55Gf5cG@5xl1`2BcG=0;=|C>hvU*6_$c{h4_am{9O*13#0X=&2vxVpH$63BW zsThyYBzxGq+r%1*52dM27@e2!(-@v!AM!SdcDgi6y1+FJi(`X zUNsKc@^^eVrBRv4LSD5o zZcR`pG?Dtv%}rjvy0)8;+9B34d!%z%u+QTfrq)QY7TX(HSRJnOK!AV#1V1(*oYD<8lr)&CiG$JU-7&G2M`XQZjh!@1XalM(~}fh<$j;wG>! zJa;+yv=n8dV{dtcRq|D64q9g0o^!_xgVpNIz20s8{T4f6%q0euw@j*D*?VW@N}n##ic~jg`}yXoZ9kQNQSi;(WFKtb~=;A+>G#=DDdUQ=6gw*G1XDi+y7co@D zOmxhJSyZ;Q-GLoi!Rjg}GqYDLwOhoJo2MFAikPcnn2f1gm&iOXn=^%z#upEAk3TkZ zG=m|HImW38^*bru8cGhMC7DJY-8yreYg3O*M)sy=w#EkKcQz$A{|Rw+M^-PQ!6hW2 zp(r4{)t?LsV+;?F`#D$2j&+53758eD#@l?^y7)_Wr0uFE#D`}^WRKl$g;&az$Jn*2EUt+#QcyNs(ZU|npXo%9hM zy$sIFIw=q5sq#3__6)QOBc+3Z>PN*MKzl~>aJ>C6ezoa4!^~E!v*@5R)Wg>nu z>Ok9C^U8LuUY{4mfip&nZm?6Gx_PX>{KhKQJkq>IoX&n>Xez?ig_6xOV8b-Yrs2$` zWR`D}Ze^`ttX@bt*k>$uMMXfRs9}E9O1a!zQu|DM)@ey$$z{nWY1cskFK*zO)Ev`N zA4yWWe0_k4f-1=7)T^EGVZ3iY;&ov9GOV_k0!_|(nYo0GM{a#M5jF&H7t07E;VDzN z6It$H@!s~H7KYTx@LM1EBavQJLYG1KEL2%Y?(1S0Qr@rpW_(P@oZ%Y(edtHJ6?Y6- z?pPt&N&Pk1slwVqrtWZIe;(#bc=plF9=W%2C^UpJWHppw@7ls&tj5a9uda3JupE0quC8d0U{4qI3;9%9x^o9N zfC4u?IhL$6vOO)XKQykDoQw1$*AVW?l<^nmT6%)DE8vGNW+yT1wvJI+QFAP2WaBT> zqmX`FEx`XN1~Zx*1;1$x#ZsikDKGa? z53!Fz>k#=Lt(^x{Q%$??L7IRRMS4e(UIL-_t{}Zj@14+lLKo?%6s32R4k8%3AV`y{ zfFOk4I{`vVZt#82x6b$a-F4PF_hu!}?D@|#douIv?6tD@?B5un|G_GIhV+eDj8}SA z4@BLPWQrN_6X)FNKP6yr{dk-O)%5$K|Kh=|Wy4;=1PS3!>C;dgEOb+YV8i1b4gD#MH%wT&RrGOk7Wsng1ZTj(QML#c%ca zY8vqzTSzFnwO%4xvbxxI&GSkB5SNL5mh+jHCtFHc57}v-Iukv^&xzqw^eSrS1xY%H zS>Eh|UC+sTCisg$0eXX5tPQtnve!M;pS*u+p4Td6?5=8_-YOm#dNgKB#wzjpy1%4f z)ysKil?UKZC~QhB0l4pHP06}VA4&T>ryur$-i2;WGJQaO@a|#y9VL?P#K&|gVT|rJ zMygBz0{i^InNTTu6*&q94w@!dqk}C$bxL>acIka z7=F&)`bcViu!h+@Q!vH&t(zbXIgWx>0?AY(TAXIZoeg zD!CEb%KWJoZ&pV#8$J1ACziQ%S}*#c9JlSqgN!FJeiRfjK5q~6(-(&E`vYGI8NaH0 ziuk8GZ8y%qf!PnxDEMJt(}b%7H}h~O$-y>3-z_v1a* zQ%sy%7VZ#IyHiV^`Xy*tt<3Xtgc?=o9ZPO54h}+112+lvK(W*uC)zh z_Du4!gX5P3NII!TR|e`TsGUdEhJ53Uee>lZgLfUlxBzZ+fRYaVWUBu2@I(`yE2rfQhT8CI!9mWIRp$rq4cFu~CeZau*x)JZbSA zN{9Hw?AS&gPB0Cn$9C|0JT!PD_FCuABalno^c(471Not&8POB@dFR#y8;|x+jqwhj z`$V`&p`NyTeIL>%)u}SXtrv#8b<w6{lLrSe8d!5KloXlS7V(%1xW`3RF;yZS<>T>5jFZi(A%*sb4 zv>}|em{b_q_$!zWSPs;0u=T>%SuF{--VATwg-bQdg3~k@ob`l#b{0wG5id z2?mPKG~@eXleW`to@F(S(nL)$cZ6qrkK=XKOIsJUY0Z=~=4wkDW!&ld{k(EC#<-Q_ z!XaOJ6|AaYrrwF;cv%f zl^sbz&NJeMO|s7EqGQ$g&m9=PfH4gphQj-@NSu_-qE6k;iB=N8Z7D&ftr_|H1|`%z zApT=4bDT*JFji(5fjRE(tk(P^IJtPY@9S;6DWVBNc|-L4e3h?ElUEMe>ock_mHGoG)8 zP)&F)M0Gi}$tn9vg~p-z7F;+1L^;mSu!`Yol+tS={*KgJjG@*4EPq?{JJ!yv2H%Q$ z@_JF(GeRgsN_`sEknK!)s4L6qE2oME#i$OI;0RyEW6{-(+%vTaltjur(5*!5cBr*@ z{x;Q!TZsC*yn3BZ#-hg{oODbuk6h!h7qg3DS<1vypF3}A9D!(<9&}tTO_jhYMVEKH z?2RU+pTE#nXgF9^#HQU{OP8G~^v>*y#$u5=8H`rQs9?+H7>D?jZiY3mv_=YjYt`Lf zPCfsgh!Q{swW(muHENXkhBDPbjnWZ5~W0Dd!?ik9Wj ziV#1g@v|kN`)*@Y#P}9AdiE~ug5lf_oP6#ME($6@7M=sHr4w3eT{L%J%m}pSS_gAv zkijTzF5ht^uTh=uFlbZk#)EyzYq4zGZ*A}2%8U&{^_}Ur6R%#@?lEb8_fM>U3n^(; zs~9`O4h#mGy;WTPq^`%SM73GFDJi32&U3`ZX)#`+%G{MDw_l?8K{X>YZ&C81f@QbO zkwH^Xy{IMMpLHM+E-^t|zV6-gm6_vBoJtSjKyS?X)1b;6jcAMb_ z*}B_A$Lg|f%&V^tZ9-b_bOw4f9@8EfaoLVB+xnD^TJ=RtIUzqBh@hXQTb-KFr=Xp) 
z>5liv$(HfH4BzW##q;?d|pGbg8VUpKVuHfmQf8x1F05; z1`MtxQt6|`Uo<7c@w;ReFH=kR*SD2+nQF=lz^RL1QY7O{8cZb=R7PIQs-Ld8*<`N)%=r+R5~+%@QKTgnTc-l>Qq{>k|^JysA)mH-~I zo6RRtPkwOKbF}5bw4iNs+}AV~&Xi=Ws;k85(DW=CImU2>?RuPPQo!ZCeZ^@A4sRO zZ-}b;;jd6FB*otxY--$PfZ-cf$X$RZ8!sn=NsMuZ861w%!5u$&cj+`L4B=f!fsoAN z#9hUNv4wYhJ!+0@I#$@6D+tNo-(wI-)5eL>!_G4Ian5B&VMM16TE(cH%+Vc{*e>&b z)h%!?kl!>b*2J+${`5XibDy5-Pd09a;WzmW2z>6BO5n-0cSYI zGo$`uH!n!PXhl~SFcNqT72V32q6bL5kmGrfsBkVZfF-p&O#&nqoT@9_|JpKh4R~5g zx}sO~RM}$w8uDwW_#W|l>OI=Me{VFUYb=@+VxZ5H3sgsn`nb55F-09Q!pRX7>Cv4nOkw!Tv zK1`Q&y9Jzub5zl;>-M>pL>XvsUKj-5Lao>}n(L>DQMalFhNGQoc_lK?x z_awYg9g zZF5l4L4iV|g1+dl>T5U7%XW`!dTX|=BB2K)%z|TQKii0tP}UzsapcSv|Zz$AW$IGO%O zNt*k>rdWaX7cHwzPINn0blI`%T>G1|UnJ=Gl4vr%%on|A`Ph=^&#`&s{`fk(<8q7< zu(w`XZoU&~F@E2h{N(OFbD{s?2dz^{hzRTTYZQ!niTJ0>s5L~U=Zwp{VA^J>W$pn9 zCQYi`cglDC^SU`aq4CWkx99bja)$ix8!!~SzcPE{MmT2$4&SS_l$?C;(~&86#+5O; z+sNE7+_j$~=tu;svA*QAWjqgWk%45qo!HgTO@NkZGF^z=_WLLg8*Q7T%CmcDD+su7 z@B8F{dIG~eNyX~v1S~`{w(ozp5ZaTiJ$Mc9ufYjooog- zL)Wpec9Y}Hky#_`2_aifitBJ_=?%c6y!9%vp}&joU7;Yq#qxR1TPLx=v^1dhOrgw_ zKCblZPc~PtM}MZZ=BdKl%)~tkB(uv|k#(~ssoZ-HQQY&3W*|#^xKB}l!|D4^gjtUB z1P8$?IgjKG7_+e_4JhFxsc4?NLhIQ?o}fpKfZT>dc|q3ZfQVj&D~0&7z&!pVb{7oZ z7W<72CM?497Lh}UrWSLBCi7Vu&^+NK?3-Kkyu9=g))M4k$H-WXRows-)$Sg%XBz3} z(&u1~NB5v6km;&Bjp~g3z+8c=I8>V&uSF@;P={1mMbx%X04G<>mE_RB<0_~2(A*8+ zf!Es&a|8FIR|g;Shth$?ig~yP12m}!*_66jcI~h%Xr`x$*vn942o3Qr5V3Yg6y(~l zL9?E$c?N6Yt&=RSz`2C%DtF(OUcwIkf!DKfvd8PdCbk54L$0iN*@K{4n-3spjETYC z*Na1&%|NMecq=Lza%I#Oo+}%)z;Ou3DeBBL1hd;(5}u_U~|4%yPOl^`dTC6Oe*3-2h#GpLtF!G{V!6t zCbD%!DROq!m=-bQxS)w&$*e zA;aXy!@lcl>>AkNf{J-#7Nd%z@M(W6`A2+{) z8^rDzl*NCz)UuO`ZU*yd4Z&pIT9mO&d`!J3(UJvwZpzn1N0SOm{IZAj^lUmj;qUF{fJw9@T8z3QmEjszcsV_= z#KW_IuKf);8}QQpA$keV@7~o19nUAATWQlKj0H5u-oSC-J;;R-jnk6{9T%h}UZ|%z zSaT~#H09$^w{FwCi)Kp2I@kPG+nw9A+ri~^Ysg!!g05e`I5rz)J>G8p`nUrFLB!9O zIdTjy8S_|QAG=VR2NkYLK&zJa7al@XZ1M$zKo2H&L?nNy54%KbqQEr4{-H68O-SH1 z$^IyUW(}ry-6?_0P3>T0pozBuw=>|Jt-@t`t~PRk7wr~^+N(BH8SZtKoW){uFSsqY zVf~Ct=2U7h`LeE8Jbp~KBY5)(ZFBqJId!?urN&y!Mre>f+x%3wM$g7shYbJTO7pwL zjI%4^^@nXlkQo?}jIeZ2&0~>xkT>L(8O#IxKGX_{zPgr(jS!KN!S-fM-Bu`;E)Ok` z4tf;_x&yhjX7bse>y$DN8Qn9Lqa%QK(JD>=W;>ue@Lv$jC)!APi;PC(kxy+{u)8$t z?$En@IbVrsugP+3XWX4n_AAYoGV(lOg;N^HhD;ogg^g%%Hd+ln1@!WrnexT`f1Cm*8rK?&8>nUS)6*=Yey6gWsY z{5Gy)+GM+dfj0vP^9$PgoPnDV0KcSdo45*xXwg@o5^J!KyCisLC6LF`Tdvna8w;@S zAM-ULd;#+zk#Avh2Ow7msz$!S3(nJ;te%{Ny#14OexEGM(A-k7vWt+5huyfx z`gf-BD24}J?n}6iXIv4z$ z+4WjR==LSa0^tD!H_v;ZALJ51|BG=2^74vd<#rYesFRu!P0}zK<VnTGxA-M+B?X3Nj7-&NiYbuf<8W7rFHjVNeRLGBw=NHYY4!8#!OEn(Ba zUXhvR#RL;>diH)3LZy1;J;wP1*h_fl(LWjcJGPB=SVPwak*j22tV9+we zH0s;s`#f@_*ik=X%}C3M3x$_TwST|PKMu4>eKy~yEG2Mw4V-*7G@_riVSgfrcttY) z^EpqZ)h}PxBQX6UfhU{Z;|GN>t5GI%N3*sfs5Mo- z^?gqC?31%+5&l4aTqEl~5AH2+{_JI04xLoTPXy}n+F;~NSLU>RW4-+w0tympr$j~X z>jgnt4!vPsSEg6N`<9S3@Y|zYB=D>NO0O>%Xe5JiC8zubtGV*QzT1Wl#ef+XeNQGB z@3Px}MQ~4X;$`lg%e2an2XIhegEV3tx-^^y`zGRl=}%m-&@#LTdEaUhnlvkRNY4^0 zF$kF3bFOIKbBdY3wsNu3O>CLTIu}!)X37eCJ4@{zIjvof`DHiZQ-*0&fWdd(!&)B> zgKnPNQ_P`HziW8XKGeh;@hJw=el+f~_nPgUV%|xs#9U`V#$rx0ZSu1-p`>R>HXk6X zA9gk{9l!FSE`^jkFv38dB|WwSFa^8u(r! zXwAQ^`y3B7$%yPCTuDM^Q7-?o3!!ymU>fG?+Oy=HXqoGXNsw4ic~e8naaWl2I(Yro z+qGk8V$LRW*?J;K&DtmEzBL+Ea-}t`s+?FzG%c#IKh#yTxGn+lo4yC8w93A=b`WbS zFu|dkeYO+Hh44rALOL=!K!5`Y;zOc-HJ?00`Z+7|l8GfotNQ!K6lOlwGzpy~dCmy! 
z`lk;q!wM}3lrJfCDmw|@)aNQjyBW&%VKnwmX4ig?7@~t~VPQYMSx$H-U=)G7XoynS zLBe8EuLtuHM08U4A9)GXHu)mm0`7Yn-uubN6xRVHSXV4mOtL9saT3%Pn6x9%P}yVE z4)BmyRS3F>FJn2zLYh^@jA1 zq!b>#@_TA-u$s~|=7;e%;3Awt_L2JhWU2T-gP)dReSX!Rb%GLCE-=rDeRwbC*rx|Ok!zb+WR_`sgw=3AWhIzB{LmtJnz_cs zo8afKx%3phQhl*iiTA`Q?A&=z=)n2_DWfz_C(DU8WxO_&(qahPg6&E>Ipih0zpCVLr{iO2^VzuZ&aQr8KuNZIbYJzWIX&EGJZ0L=?2Vs@~;7kLENOr>XQDA%(b4m-_(sX}`qIB3`w2(BkQbwzNI!`V(Z;n~Aa#3^ z)X{Flwy-%zCL@D>)uLMxqCSbULXr+@bx!fLlc(_Thg)Fe;5$xD-h3U|h=dH520iv` z)hGk&NdxQBn|GluR{{7T>_s-McBx|ftIX9wN1<8N!Q`zk?qc- zGD`2~nRk8DgV=vBIugE9rQdMsqUW%JU- z*=gY%T6~T%4KK{G!KO5i733HNH@wHqtCz;f#x-c}%uLdZw2->%i~XE;URxP6i0j=_ z4GAt$DXB_*UYOjlR>JJ-ZDO9pXGOAQgMR$6aPnzh#lRz}zOA2t`_2WC*{{@QMxPCH z8P7FO?a#7Sr5m`+d+``@O$rZ`KudZ`XzfC)JEkr6U>h}Fs!g3?l^L~=68(Vm*A@|{ z7&Nudn(OpANDO_gpN5#;&U#o!VX(Pb`-b{2%AGmk`p9*zs6 z*eV>Ns-Lw9l5~C>kJN}Fy&H3%t|X$Hgjh!OYTN#VDN_*?P$NgNpZgR!%ok^Ay~7a-9&i#bpHd6xOR3!pQwd*OEEgF8vJea}+T_5cw`C zT%BHM4_}C<;#(-XW&(QefKwEhn+r_(zOiJ#F9AgSQ8r zA6;4H=xeGCn<_Nlm+o#c>#Oo7i`0mKL}cuQ-9#7#)R5&q{meV;PYLHZbJY&7=sF)o zIF#JhS7mOXjmV*YA;Zj3%*t>WI-jw-J(#;S*m!2wbZVz>u<28JTW!6}3!mQMBvT6X zX`e|eH$#RleKS2LFEFGkFfwG<6_k44S#ESBjJ{KOfhM1h9V;%}w3BW8qA;W_-AxZl z*@st;Ap?r|<<8>8`mDk1RivBiWn^S!YtGfJgI%X@d^k>^Z;T{IiY~~f7HMg_#UXp- zbzcGACUQ+p7sL-PU$Xz4+?vIzhyI|v<>dNpAJ5Cr1zFL~iI*iW9sNN=c;jJ0)(SuVLOEHGW6s1Fjqlj-8Is%S#caPd6Cw&{X1z=?46IM%_QlVVN~ zGkd^oV|2a6sI99(6m}Q3+dk5XU z%Kg=Q`Sp(5N+}@NFTd(#-GN!-!bwA^@@j(-1C@6QKf*GrPPGMp*v1boHGDDIp~&j7 z2AJHBw`EonKoRxIAtgLte_JwK&Ll_bDNb`7mLV9ywfcZ~MH`Onhf}H_Un|$nczg}i zlbB$XmgX28b|ou~Ru8|+y;EK?)Bm6=GVscNvE6)Kul@ zG0wmGNy@v7DfOX1IOOCZO6;Vi7`q=(9+Yu^^eJSN>o?k+Aou1d%IZK;#X*t zaQ(=s*=Jt;>=$Tve7h21^mgS|Dn61V)IRxkJ^DeX3`q%RKTADb%kQrF9Iad9(D?)O zEhdWllankm3Rjj+|8(ho!k&jfbxal7@mIs)mJu7i^ z`7T*~=6fQ+4U9l-xya2Hi_-dh;*hGaikkf?nau^SE-;>F-E1MZ&f?&mCT+x1j-tmH z;jHGYUNSavkmO?M$5&4G?;ov%&$)hhxvJ=8ZHe8ZsMf0G7O)UC<1tz6bYiOCHIu-Y z`yu*D#WvXKd#m%6QpigTqz~T{!q5YUF_LR;^VR_7P;O%0hvh_PIeN4VKujrZec%th zxv&OE9nhGi;hIf&k8n?2g@7~tZk5wKh zPY=(keevF}7m6_D_O!-y*HkBM538r7w|SkzYR~t9D#0X;_O@})O#+yxTyOK<@951A ziW7IdD)+hFrdCUs_>3}ci?Zj@!_4;4ci4h=-N2p+Zgrm@3Q!W3IL;xk2@q5eSMB07cPWKruA0=0^H=Ru7u$^V>~Ggh2R)>LV=r zH#z4oMvpKWT64qZK^LI6Zaz2Y9Wyivh(P29FN0>)m=K7HilL<&RFB|q7c{ZwhD#)f zmT&6b@Q2XFq6A_$s2(u@ftbK={olwuHy!yiRR1%X=U>tJ&v+gI5j37h1kD-x3(tcd zpZ_hMN9g~8=eZf1f5Y?KjN(mC1cmwk#T~i{z>Ov#D9SJLFZE^se|w_C_&1^GH!S8x z1BeI+iv5>*LxZ{zB_#>?{$f>WSvlBw6Y!}y*?Ql^3tg^r(=!E-FPb|9_;1Lg053ax z0)8em;|V{bm2l51j)D|KSG!3jCo7 zhzb3n0fqVhp#elh0lJ-DkFf%X3JM7RV}C$F0Pvq0K=jWTp?{5ki~&$k z5ctpj3Il*5|J)x?5Z#e~v?~k{7Wrq~(ApodL?17Je~uMe6Z~(QkCzo1Ugh=sfYo&h vvb$l7@#%m-J_I+XWc2y&H!JN1Wc4=~$;!*;_mL_L5Cw=5u(B#?Dii!4;$c`x literal 0 HcmV?d00001 diff --git a/bytelatent/.DS_Store b/bytelatent/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5a50b07181fde5f726c05ad04620e1e18f4fc419 GIT binary patch literal 6148 zcmeHLJx;?g6n=(4tFm+fF=k}~76w>CRd%E<2!x{q6J=)y2iUM<=OP8*$RWNO5ofru%aD zwC(b0{lKet#;Zk^EyOewP>bp`q`F#3^qAPVGRI5~N2{ZZ!#K8+d}eozTY1GZ=bfVt zbZF5D#a3o|-?B%|+-r~H+(x{a-R)uLU*r5m>QM{^LwwG0KiW2nrIE|WF_go&fgn6oxY#X)-+$6=L6FJ8N;5 dict[str, Any]: + return np.random.default_rng((seed, rank, world_size)).bit_generator.state + + +def distribute_data_to_rank( + *, + dataset_path: str, + preprocess_dir: str, + entropy_model_name: str | None, + arrow_batch_size: int, + rank: int, + world_size: int, +) -> ArrowFileIterator: + dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + n_workers_per_chunk = world_size // len(dataset_chunks) + rank_to_arrow_iterator_params = [] + for chunk_path in dataset_chunks: + for 
worker_id in range(n_workers_per_chunk): + rank_to_arrow_iterator_params.append( + ArrowFileIterator( + file_path=chunk_path, + worker_id=worker_id, + num_workers=n_workers_per_chunk, + preprocess_dir=preprocess_dir, + dataset_files=None, + entropy_model_name=entropy_model_name, + arrow_batch_size=arrow_batch_size, + ) + ) + return rank_to_arrow_iterator_params[rank] + + +class DataloaderArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + root_dir: str | None = None + sources: dict[str, float] = {} + batch_size: int = 2 + seq_len: int = 2048 + seed: int = 42 + add_bos: bool = True + add_eos: bool = True + load_async: bool = True + prefetch_size: int = 64 + preprocess_dir: str | None = None + dataset_files: list[str] | None = None + entropy_model_name: str | None = "transformer_100m" + arrow_batch_size: int = 100 + buffer_size: int = 64 + + pad_to_max_length: bool = True + max_encoder_seq_length: int = 12288 + enable_byte_ngrams: bool = False + + tokenizer_args: TokenizerArgs = TokenizerArgs() + patcher_args: PatcherArgs = PatcherArgs() + + def _create_sequence_iterators( + self, rank: int, world_size: int + ) -> dict[str, SequenceIterator]: + sequence_packing_args = SequencePackingArgs( + output_seq_len=self.seq_len, + buffer_size=self.buffer_size, + ) + source_to_sequence_iterator: dict[str, SequenceIterator] = {} + for dataset_path in self.sources: + shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size) + arrow_iterator = distribute_data_to_rank( + dataset_path=os.path.join(self.root_dir, dataset_path), + preprocess_dir=self.preprocess_dir, + entropy_model_name=self.entropy_model_name, + arrow_batch_size=self.arrow_batch_size, + rank=rank, + world_size=world_size, + ) + looping_iterator = LoopingIterator(arrow_iterator) + preprocess_iterator = PreprocessIterator( + looping_iterator, + patcher_args=self.patcher_args, + tokenizer_args=self.tokenizer_args, + ) + sequence_iterator = SequenceIterator( + preprocess_iterator, + sequence_packing_args=sequence_packing_args, + rng_state=shuffle_rng_state, + ) + + source_to_sequence_iterator[dataset_path] = sequence_iterator + return source_to_sequence_iterator + + def build_from_rank( + self, rank: int, world_size: int + ) -> StatefulIterator[Batch, Any]: + source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size) + weight_rng_state = get_rng_state(self.seed + 1, rank, world_size) + sampling_iterator = SamplingIterator( + rng_state=weight_rng_state, + source_to_weight=self.sources, + source_to_iterator=source_to_sequence_iterators, + ) + tokenizer = self.tokenizer_args.build() + packing_args = PackingArgs( + batch_size=self.batch_size, + seq_len=self.seq_len, + pad_id=tokenizer.boe_id, + max_length=self.max_encoder_seq_length, + pad_to_max_length=self.pad_to_max_length, + enable_byte_ngrams=self.enable_byte_ngrams, + ) + packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) + mp_iterator = MultiprocessIterator( + packing_iterator, n_batches_to_prefetch=self.prefetch_size + ) + + return mp_iterator + + +class TrainArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + name: str = "lingua" + dump_dir: str = "" + + seed: int = 42 + + # Number of gradient accumulation steps + # Total batch size is batch_size*grad_acc_steps + grad_acc_steps: int = 1 + + gc_collect_freq: int = 1000 + probe_freq: int | None = None + + # Nb optimizer steps to take + steps: int = 1000 + + data: DataloaderArgs = DataloaderArgs() + optim: OptimArgs = OptimArgs() + model: 
ByteLatentTransformerArgs = ByteLatentTransformerArgs() + distributed: DistributedArgs = DistributedArgs() + env: EnvironmentArgs = EnvironmentArgs() + + checkpoint: CheckpointArgs = CheckpointArgs() + profiling: ProfilerArgs = ProfilerArgs() + logging: LoggingArgs = LoggingArgs() + + # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus + async_eval_gpus: int | None = None + eval: Any | None = None + eval_on_gpus: int | None = None + + def dump_to_yaml_file( + self, path: str, log_config: bool = True, sort_keys: bool = True + ): + model_dict = self.model_dump(mode="json") + yaml_str = yaml.dump( + model_dict, + allow_unicode=True, + sort_keys=sort_keys, + default_flow_style=False, + ) + with open(path, "w") as f: + if log_config: + logger.info("Using the following config for this run:") + logger.info(yaml_str) + f.write(yaml_str) diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py new file mode 100644 index 0000000..f494a15 --- /dev/null +++ b/bytelatent/base_transformer.py @@ -0,0 +1,585 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from enum import Enum +from typing import Optional, Tuple, Union + +import torch +from pydantic import BaseModel +from torch import nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import ( + BlockMask, + _mask_mod_signature, + flex_attention, +) +from xformers.ops import AttentionBias, fmha + +from bytelatent import probe + +flex_attention_comp = torch.compile(flex_attention) + + +class InitStdFactor(Enum): + DISABLED = "disabled" # Init std is divided by 1.0 + GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) + CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) + DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 + + +class BaseTransformerArgs(BaseModel): + dim: int = 512 + n_layers: int = 8 + head_dim: Optional[int] = None + n_heads: Optional[int] = None + n_kv_heads: Optional[int] = None + + ffn_dim_multiplier: Optional[float] = None + + multiple_of: int = 256 + + norm_eps: float = 1e-5 + + rope_theta: float = 10000.0 + + init_base_std: Optional[float] = None + init_std_factor: InitStdFactor = InitStdFactor.DISABLED + + max_seqlen: int = 1024 + + +def cross_entropy(pred, target, **kwargs): + return F.nll_loss( + F.log_softmax(pred.flatten(end_dim=-2).float(), -1), + target.flatten(end_dim=-1), + **kwargs, + ) + + +def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 
+ + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + + cos, sin = freqs.cos(), freqs.sin() + + return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + seq_dim (int): Sequence dimension index. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= seq_dim < ndim + assert freqs_cis.shape == ( + x.shape[seq_dim], + x.shape[-3], + 2, + 2, + ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" + shape = [ + d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2]) + ] + [2, 2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + seq_dim: int, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + freqs_cis = reshape_for_broadcast( + freqs_cis, xq_, seq_dim + ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + xq_out = (xq_ * freqs_cis).sum(5).flatten(3) + xk_out = (xk_ * freqs_cis).sum(5).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def lengths_to_start_ids(lengths): + doc_start = lengths.cumsum(0) + doc_start = doc_start.roll(1) + doc_start[0] = 0 + return doc_start + + +def lengths_to_local_ids(lengths): + assert lengths.ndim == 1 + nb_seqs = lengths.size(0) + total_seqlen = lengths.sum() + # This gives the document id of each token + doc_id = torch.repeat_interleave(lengths) + # Compute document start for each document + doc_start = lengths_to_start_ids(lengths) + # Compute document start for each token + doc_start = doc_start[doc_id] + # Compute the position of each token within each document + tok_id = torch.arange(total_seqlen, device=lengths.device) - doc_start + + return doc_id, tok_id + + +def generate_doc_mask_mod( + mask_mod: _mask_mod_signature, + lengths: torch.Tensor, + kv_lengths: Optional[torch.Tensor] = None, +) -> _mask_mod_signature: + """Generates mask mods that apply to inputs to flex attention in the sequence stacked + format. + + Args: + mask_mod: The mask mod to apply to the documents + lengths: Lengths of each document + + Note: + What is the sequence stacked format? When assembling batches of inputs, we + take multiple sequences and stack them together to form 1 large sequence. We then + use masking to ensure that the attention scores are only applied to tokens within + the same document. 
+ + Example: + + - Square mask + doc_mask lengths + a a b b b c c 2 3 2 + a 1 0 0 0 0 0 0 + a 1 1 0 0 0 0 0 + b 0 0 1 0 0 0 0 + b 0 0 1 1 0 0 0 + b 0 0 1 1 1 0 0 + c 0 0 0 0 0 1 0 + c 0 0 0 0 0 1 1 + + """ + kv_lengths = kv_lengths if kv_lengths is not None else lengths + q_document_id, q_token_id = lengths_to_local_ids(lengths) + kv_document_id, kv_token_id = lengths_to_local_ids(kv_lengths) + q_max_idx = lengths.sum() - 1 + kv_max_idx = kv_lengths.sum() - 1 + + def doc_mask_mod(b, h, q_idx, kv_idx): + q_idx_cap = torch.minimum(q_max_idx, q_idx) + kv_idx_cap = torch.minimum(kv_max_idx, kv_idx) + valid_idx = (q_idx <= q_max_idx) & (kv_idx <= kv_max_idx) + same_doc = q_document_id[q_idx_cap] == kv_document_id[kv_idx_cap] + q_logical = q_token_id[q_idx_cap] + kv_logical = kv_token_id[kv_idx_cap] + inner_mask = mask_mod(b, h, q_logical, kv_logical) + return same_doc & inner_mask & valid_idx + + return doc_mask_mod + + +# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. +class RotaryEmbedding(torch.nn.Module): + """ + RotaryEmbedding Module + """ + + def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024): + super().__init__() + + self.theta = theta + self.head_dim = head_dim + self.max_seqlen = max_seqlen + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta), + persistent=False, + ) + + def reset_parameters(self): + self.freqs_cis[...] = precompute_freqs_cis( + dim=self.head_dim, end=self.max_seqlen, theta=self.theta + ) + + def forward( + self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None + ): + """ + Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions + Args: + seqlen (int): Contiguous sequence length + tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen + + Returns: + Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis + """ + test = (seqlen is not None) or (tok_idx is not None) + assert test, "Should provide atleast seqlen or tok_idx" + if tok_idx is not None: + return self.freqs_cis[tok_idx] + elif seqlen is not None: + return self.freqs_cis[0:seqlen] + + +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. 
+ + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + x = probe.log_stats(x, "resid") + output = self._norm(x.float()) + return (output * self.weight.float()).type_as(x) + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + rope_theta: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + self.rope_theta = rope_theta + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + # B S D + bsz, seq_len, dim = x.shape + xq = self.wq(x.view_as(x)) + xk = self.wk(x.view_as(x)) + xv = self.wv(x.view_as(x)) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) + + # This condition helps us be easily compatible + # with inference by adding a pluggable KVCache + if hasattr(self, "kv_cache"): + xk, xv = self.kv_cache.update(xk, xv, tok_idx) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + if attn_impl == "flex_attention": + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + elif attn_impl == "fmha": + assert mask is None or isinstance(mask, AttentionBias) + output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + # This uses B S H D instead of B H S D of pytorch + + elif attn_impl == "sdpa": + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + assert mask is None or isinstance(mask, (str, torch.Tensor)) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask if isinstance(mask, torch.Tensor) else None + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + else: + raise NotImplementedError( + f"Attention implementation {attn_impl} not supported" + ) + + output = self.wo(output.reshape(output_shape)) + + return output + + def reset_parameters(self, init_std=None, factor=1.0): + init_std = init_std or (self.dim ** (-0.5)) + + for w in [self.wq, self.wk, self.wv]: + nn.init.trunc_normal_( + w.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=init_std / factor, + a=-3 * init_std, + b=3 * init_std, + ) 
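
Note on the KV cache hook above: Attention.forward only checks `hasattr(self, "kv_cache")` and, if present, calls `self.kv_cache.update(xk, xv, tok_idx)`. The cache object itself is attached externally at inference time. The snippet below is a minimal, illustrative sketch of a cache compatible with that call; the class name, preallocation to a fixed maximum sequence length, the assumption that `tok_idx` is a 1-D tensor shared across the batch, and the "return the full cached prefix" convention are all assumptions for illustration, not the cache actually used by this codebase.

import torch


class SimpleKVCache(torch.nn.Module):
    """Illustrative sketch of an object satisfying the `self.kv_cache.update(xk, xv, tok_idx)` call."""

    def __init__(self, bsz: int, max_seqlen: int, n_kv_heads: int, head_dim: int, dtype=torch.float32):
        super().__init__()
        shape = (bsz, max_seqlen, n_kv_heads, head_dim)
        # Preallocate caches in the same B S H D layout Attention.forward uses before repeat_kv
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype), persistent=False)
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype), persistent=False)
        self.offset = 0  # number of positions filled so far

    def update(self, xk: torch.Tensor, xv: torch.Tensor, tok_idx=None):
        # xk, xv: (bsz, seq_len, n_kv_heads, head_dim); tok_idx (if given) is assumed to be
        # a 1-D tensor of absolute positions shared across the batch.
        bsz, seq_len, _, _ = xk.shape
        if tok_idx is None:
            tok_idx = torch.arange(self.offset, self.offset + seq_len, device=xk.device)
        self.k_cache[:bsz, tok_idx] = xk.to(self.k_cache.dtype)
        self.v_cache[:bsz, tok_idx] = xv.to(self.v_cache.dtype)
        self.offset = int(tok_idx.max().item()) + 1
        # Return every cached position seen so far, so attention covers the whole prefix.
        return (
            self.k_cache[:bsz, : self.offset].type_as(xk),
            self.v_cache[:bsz, : self.offset].type_as(xv),
        )

Under these assumptions, decoding would set `attn.kv_cache = SimpleKVCache(...)` on each Attention module before generation; during training the attribute is absent, so the `hasattr` check above is a no-op.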
+ + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + mp_size: int = 1, + ): + super().__init__() + + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + assert hidden_dim % mp_size == 0 + + self.dim = dim + self.hidden_dim = hidden_dim + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # B S D + x1 = self.w1(x.view_as(x)) + x3 = self.w3(x.view_as(x)) + output = self.w2(F.silu(x1) * x3) + return output + + def reset_parameters(self, init_std=None, factor=1.0): + in_init_std = init_std or (self.dim ** (-0.5)) + out_init_std = init_std or (self.hidden_dim ** (-0.5)) + in_init_std = in_init_std + out_init_std = out_init_std / factor + for w in [self.w1, self.w3]: + nn.init.trunc_normal_( + w.weight, + mean=0.0, + std=in_init_std, + a=-3 * in_init_std, + b=3 * in_init_std, + ) + nn.init.trunc_normal_( + self.w2.weight, + mean=0.0, + std=out_init_std, + a=-3 * out_init_std, + b=3 * out_init_std, + ) + + +class TransformerBlock(nn.Module): + def __init__(self, args: BaseTransformerArgs): + super().__init__() + + assert (args.head_dim is not None) or ( + args.n_heads is not None + ), "Should specify at least head_dim or n_heads" + self.head_dim = args.head_dim or args.dim // args.n_heads + self.n_heads = args.n_heads or args.dim // args.head_dim + self.n_kv_heads = args.n_kv_heads or self.n_heads + + assert args.n_heads % self.n_kv_heads == 0 + assert args.dim % args.n_heads == 0 + + self.attention = Attention( + dim=args.dim, + head_dim=self.head_dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + rope_theta=args.rope_theta, + ) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + h = x + self.attention( + self.attention_norm(x), + freq_cis, + tok_idx=tok_idx, + mask=mask, + attn_impl=attn_impl, + ) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self, init_std=None, factor=1.0): + self.attention.reset_parameters(init_std, factor) + self.attention_norm.reset_parameters() + + self.feed_forward.reset_parameters(init_std, factor) + self.ffn_norm.reset_parameters() + + +class BaseTransformer(nn.Module): + def __init__(self, args: BaseTransformerArgs): + super().__init__() + self.dim = args.dim + self.init_base_std = args.init_base_std + self.init_std_factor = InitStdFactor(args.init_std_factor) + self.max_seqlen = args.max_seqlen + self.rope_embeddings = RotaryEmbedding( + theta=args.rope_theta, + head_dim=args.head_dim or args.dim // args.n_heads, + max_seqlen=args.max_seqlen, + ) + + self.layers = nn.ModuleList() + for _ in range(args.n_layers): + self.layers.append(TransformerBlock(args)) + + def forward( + self, + h, + tok_idx: Optional[torch.Tensor] = None, + mask: 
Optional[Union[BlockMask, AttentionBias, str]] = None, + attn_impl: str = "sdpa", + ): + + freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + return h + + def reset_parameters(self): + # Either use fixed base std or sqrt model dim + self.rope_embeddings.reset_parameters() + + def init_weights(self): + self.reset_parameters() + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(self.init_base_std, factor) diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py new file mode 100644 index 0000000..bcf591e --- /dev/null +++ b/bytelatent/checkpoint.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import json +import logging +import os +import re +from pathlib import Path +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import torch.nn as nn +import torch.optim.optimizer +from pydantic import BaseModel, ConfigDict +from torch.distributed._tensor import DeviceMesh +from torch.distributed.checkpoint.format_utils import dcp_to_torch_save +from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + get_state_dict, + set_state_dict, +) + +from bytelatent.distributed import get_is_master + +logger = logging.getLogger("CHECKPOINT") + +FOLDER_NAME = "{:010d}" +RE_FOLDER = r"\d{10}" + +RE_CKPT = r"__\d_\d\.distcp" + +CONSOLIDATE_FOLDER = "consolidated" +CONSOLIDATE_NAME = "consolidated.pth" + +CONFIG_NAME = "params.json" +TRAIN_STATE_NAME = "train_state_{:05d}.json" +RE_DIGITS = re.compile(r"\d+") + + +class SaveEvery(BaseModel): + model_config = ConfigDict(extra="forbid") + every: int = 1000 + keep: int = 0 + + +class CheckpointArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dump: SaveEvery = SaveEvery() + eval: SaveEvery = SaveEvery() + path: str | None = None + init_ckpt_path: str | None = None + continue_training_from_init: bool = False + + +def _get_key_step(name: str): + return int(re.findall(RE_DIGITS, name)[-1]) + + +def consolidate_checkpoints(ckpt_dir: str): + """ + Consolidates all FSDP checkpoints in a directory to a single file + Consolidate checkpoint is saved in a subdirectory of ckpt_dir + + Parameters: + ckpt_dir: str - path to the directory containing the checkpoints + + Returns the path to the consolidated checkpoint + """ + consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER + if not (consolidate_path / CONSOLIDATE_NAME).exists(): + consolidate_path.mkdir(exist_ok=True) + logger.info(f"Consolidating to: {str(consolidate_path)}") + dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) + (consolidate_path / CONFIG_NAME).write_text( + (Path(ckpt_dir) / CONFIG_NAME).read_text() + ) + logger.info("Consolidated !") + return consolidate_path + + +def load_from_checkpoint( + ckpt_dir: str, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + model_key: str = "model", + optim_key: str = "optim", +): + if not (Path(ckpt_dir) / ".metadata").exists(): + raise ValueError( + f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" + ) + 
+ state_dict = {} + if optimizer is not None: + state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer) + else: + state_dict[model_key] = get_model_state_dict(model) + if model_key == "": # If only loading a model directly, the key should be empty + state_dict = state_dict.pop(model_key) + + dcp.load(state_dict, checkpoint_id=ckpt_dir) + + +class CheckpointManager: + def __init__(self, args: CheckpointArgs): + self.path = args.path + self.dump_every = args.dump + self.eval_every = args.eval + self.init_ckpt_path = args.init_ckpt_path + self.continue_training_from_init = args.continue_training_from_init + + assert os.path.exists( + self.path + ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" + + self.existing_saves = self.get_existing_saves() + + def get_existing_saves(self) -> List[Path]: + folders = [ + p + for p in Path(self.path).iterdir() + if p.is_dir() and re.match(RE_FOLDER, p.name) + ] + folders.sort(key=lambda p: _get_key_step(p.name)) + return folders + + def clean_up(self): + logger.info("Cleaning up checkpoints...") + dump_folders = [] + eval_folders = [] + other_folders = [] + for p in self.existing_saves: + is_dump = _get_key_step(p.name) % self.dump_every.every == 0 + is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + if is_dump: + dump_folders.append(p) + if is_eval: + eval_folders.append(p) + if not (is_dump or is_eval): + other_folders.append(p) + + logger.info(f"Dump folders: {dump_folders}") + logger.info(f"Eval folders: {eval_folders}") + logger.info(f"Other folders: {other_folders}") + + if self.dump_every.keep > 0: + dump_folders = dump_folders[-self.dump_every.keep :] + if self.eval_every.keep > 0: + eval_folders = eval_folders[-self.eval_every.keep :] + + folder_to_keep = set(other_folders + dump_folders + eval_folders) + folder_to_remove = set(self.existing_saves) - folder_to_keep + + logger.info(f"Removing folders: {folder_to_remove}") + + if dist.get_rank() == 0: + for folder in folder_to_remove: + for file in folder.iterdir(): + if file.is_file(): + file.unlink() + elif file.is_dir(): + assert file.name in [CONSOLIDATE_FOLDER] + for f in file.iterdir(): + f.unlink() + file.rmdir() + folder.rmdir() + + dist.barrier() + + self.existing_saves = list(folder_to_keep) + self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + + def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + path = None + for p in reversed(self.existing_saves): + if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + path = p + break + return path + + def _create_folder(self, base_path: Path, folder_name: str) -> Path: + folder = base_path / folder_name + if get_is_master(): + folder.mkdir(parents=False, exist_ok=True) + if dist.is_initialized(): + dist.barrier() + return folder + + def _get_dp_tp_mesh( + self, device_mesh: Optional[DeviceMesh] = None + ) -> Tuple[int, int]: + dp_rank = 0 + tp_rank = 0 + if device_mesh is not None: + if "dp_replicate" in device_mesh.mesh_dim_names: + dp_rank = device_mesh.get_local_rank("dp_replicate") + if "dp_shard" in device_mesh.mesh_dim_names: + dp_rank = dp_rank * device_mesh[ + "dp_replicate" + ].size() + device_mesh.get_local_rank("dp_shard") + if "tp" in device_mesh.mesh_dim_names: + tp_rank = device_mesh.get_local_rank("tp") + return dp_rank, tp_rank + + @torch.no_grad() + def get_state_dict( + self, + model, + optimizer, + ): + model_sd, optim_sd = get_state_dict(model, optimizer) + return {"model": model_sd, "optim": 
optim_sd} + + def save( + self, + model, + optimizer, + train_state, + config, + device_mesh: Optional[DeviceMesh] = None, + ) -> bool: + + # When creating directory check if only rank0 or is there other solution + path = Path(self.path) + curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) + logger.info(f"Saving to: {str(curr_save_dir)}") + + if dist.is_initialized(): + dist.barrier() + + logger.info("Saving...") + state_dict = self.get_state_dict(model, optimizer) + dcp.save(state_dict, checkpoint_id=curr_save_dir) + logger.info("State dict saved!") + + if dist.is_initialized(): + dist.barrier() + + if get_is_master(): + config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + + # Add json dump here + dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) + if tp_rank == 0: + train_state_name = TRAIN_STATE_NAME.format(dp_rank) + logger.info( + f"Saving train state to: {str(curr_save_dir / train_state_name)}" + ) + with open(curr_save_dir / train_state_name, "w") as f: + json.dump(train_state.state_dict(), f) + logger.info("Train state saved !") + + self.existing_saves.append(curr_save_dir) + + self.clean_up() + + if dist.is_initialized(): + dist.barrier() + return True + + @torch.no_grad() + def load( + self, + model: nn.Module, + optimizer, + train_state, + device_mesh: DeviceMesh, + path: Optional[Path] = None, + ): + dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) + # Loading tries to load the provided path, if not available the last saved step and finally from the init path + path = path or self.get_last_step_path(dp_rank=dp_rank) + # If none of those are available don't do anything + if path is None: + # If no checkpoints exist do nothing + return + + # Only load train state if it's provided, the files exist and we're not loading from init path + train_state_name = TRAIN_STATE_NAME.format(dp_rank) + logger.info("Reloading train state") + with open(path / train_state_name, "r") as f: + train_state_dict = json.load(f) + train_state.load_state_dict(train_state_dict) + logger.info("Train state reloaded") + + logger.info(f"Loading from: {str(path)}") + state_dict = self.get_state_dict( + model=model, + optimizer=optimizer, + ) + dcp.load(state_dict, checkpoint_id=path) + logger.info("State dict loaded.") + + logger.info("Reloading model and optim") + + set_state_dict( + model, + optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optim"], + ) + logger.info("Model and optim reloaded") + + @classmethod + def instantiate_and_make_dir(cls, args: CheckpointArgs): + if get_is_master(): + os.makedirs(args.path, exist_ok=True) + dist.barrier() + + return cls(args) diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml new file mode 100644 index 0000000..5f6debb --- /dev/null +++ b/bytelatent/configs/debug.yaml @@ -0,0 +1,110 @@ +# Template config, need to change dump_dir, data.root_dir and tokenizer.path +# Evals can be activated by uncommenting its config +# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest + +dump_dir: /tmp/ +name: "debug" +steps: 100_000 +probe_freq: null +seed: 777 +optim: + lr: 4e-04 + warmup: 500 + lr_min_ratio: 0.1 + clip: 10.0 + +distributed: + fsdp_type: full_shard + compile: true + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +model: + n_heads: 8 + dim: 512 + vocab_size: 260 + dim_token: 256 + patch_size: 6 + tokenization_mode: "bytes" + patching_mode: "space" + 
tie_local_encoder_decoder_logits: false + data_loader_patching: true + max_encoder_seq_length: 12288 + pad_to_max_length: true + patching_threshold: 3.1439168453216553 + encoder_hash_byte_group_size: [4] + encoder_hash_byte_group_vocab: 50002 + encoder_hash_byte_group_nb_functions: 3 + encoder_enable_byte_ngrams: false + cross_attn_encoder: true # assuming cross_attention is true + cross_attn_decoder: true # assuming cross_attention is true + cross_attn_window_encoder: 512 + cross_attn_window_decoder: 512 + dim_local_encoder: 256 + dim_local_decoder: 256 + cross_attn_k: 8 + cross_attn_nheads: 4 + cross_attn_all_layers_decoder: true + cross_attn_all_layers_encoder: true + cross_attn_use_flex_attention: true + cross_attn_init_by_pooling: true + log_patch_lengths: true + non_linearity: "swiglu" + use_rope: true + recompute_fc1_out: false + recompute_fc3_out: false + recompute_attn: false + custom_bwd: false + layer_ckpt: "none" + efficient_attn: "sdpa" + patch_only_encoder: false + patch_only_decoder: false + use_local_encoder_transformer: true + init_use_gaussian: true + init_use_depth: "current" + attn_bias_type: "block_causal" + alpha_depth: "disabled" + max_length: 256 + local_attention_window_len: 512 + max_seqlen: 12288 + downsampling_by_pooling: "max" + +data: + root_dir: ??? + sources: + dclm_baseline_1.0: 1.0 + batch_size: 2 + prefetch_size: 64 + seq_len: 4096 + load_async: true + preprocess_dir: ??? + tokenizer_args: + name: blt + init_kwargs: + bpe_tokenizer_path: ??? + +profiling: + run: false + +checkpoint: + dump: + every: 500 + keep: 3 + eval: + every: 1000 + keep: -1 + +logging: + freq: 10 + +eval_on_gpus: 8 +eval: + dataset_dir: /checkpoint/amaia/codegen/datasets/eval + tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu + generator: + max_tokens: 65536 + dtype: bf16 + + mp_size: 1 diff --git a/bytelatent/constants.py b/bytelatent/constants.py new file mode 100644 index 0000000..341c7ff --- /dev/null +++ b/bytelatent/constants.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import os +from pathlib import Path + +BLT_DATA = Path(os.environ.get("BLT_DATA", "data")) diff --git a/bytelatent/data/__init__.py b/bytelatent/data/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/data/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py new file mode 100644 index 0000000..7e142e4 --- /dev/null +++ b/bytelatent/data/data_types.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
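The checkpoint section of the debug config above maps onto CheckpointArgs and SaveEvery from bytelatent/checkpoint.py, and CheckpointManager.clean_up() prunes old save folders according to those settings (keep <= 0 disables pruning for that category). A hedged sketch of how the values interact, with a hypothetical dump path and step numbers:

from bytelatent.checkpoint import CheckpointArgs, SaveEvery

# Mirrors the yaml above: dump every 500 steps keeping the last 3, eval every 1000 keeping all.
ckpt_args = CheckpointArgs(
    dump=SaveEvery(every=500, keep=3),
    eval=SaveEvery(every=1000, keep=-1),
    path="/tmp/debug/checkpoints",  # hypothetical; must exist before constructing CheckpointManager
)

# Given existing save folders for steps 500, 1000, ..., 4000, clean_up() would retain
# the last 3 dump steps {3000, 3500, 4000} plus every eval step {1000, 2000, 3000, 4000}.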
+import json +from dataclasses import dataclass +from typing import Any, Iterator + +import numpy as np +from pydantic import BaseModel, ConfigDict + + +class BltExample(BaseModel): + model_config = ConfigDict(extra="forbid") + sample_id: str + text: str + tokens: list[int] | None + entropies: list[float] | None + patch_lengths: list[int] | None + mask: list[bool] | None + + +class MultiChoiceState(BaseModel): + model_config = ConfigDict(extra="forbid") + root_dir: str + sources: dict[str, float] + source_to_state: dict[str, Any] + rng_state: dict[str, Any] + + +class PrefetchState(BaseModel): + model_config = ConfigDict(extra="forbid") + seq_idx: int + rng_state: dict[str, Any] + prefetch_size: int + batch_size: int + + +class BltPackTokensState(BaseModel): + model_config = ConfigDict(extra="forbid") + start_token: int + output_seq_len: int + n_views: int = 2 + + +class DataLoaderState(BaseModel): + model_config = ConfigDict(extra="forbid") + multi_choice_state: MultiChoiceState + pack_tokens_state: BltPackTokensState + prefetch_state: PrefetchState + + +BltIterator = Iterator[tuple[BltExample, DataLoaderState]] + + +class BltSequence(BaseModel): + tokens: list[int] + mask: list[bool] + patch_lengths: list[int] + + +@dataclass +class Batch: + x: np.ndarray + y: np.ndarray + mask: np.ndarray | None = None + patch_lengths: np.ndarray | None = None + ngram_ids: np.ndarray | None = None + is_final: bool = False + + def to_python_dict(self) -> dict: + x = self.x.tolist() + y = self.y.tolist() + if self.mask is None: + mask = None + else: + mask = self.mask.tolist() + if self.patch_lengths is None: + patch_lengths = None + else: + patch_lengths = self.patch_lengths.tolist() + if self.ngram_ids is None: + ngram_ids = None + else: + ngram_ids = self.ngram_ids.tolist() + return { + "x": x, + "y": y, + "mask": mask, + "patch_lengths": patch_lengths, + "ngram_ids": ngram_ids, + "is_final": self.is_final, + } + + @classmethod + def from_python_dict(cls, data: dict) -> "Batch": + x = np.array(data["x"]) + y = np.array(data["y"]) + if data["mask"] is None: + mask = None + else: + mask = np.array(data["mask"]) + if data["patch_lengths"] is None: + patch_lengths = None + else: + patch_lengths = np.array(data["patch_lengths"]) + if data["ngram_ids"] is None: + ngram_ids = None + else: + ngram_ids = np.array(data["ngram_ids"]) + return Batch( + x=x, + y=y, + mask=mask, + patch_lengths=patch_lengths, + ngram_ids=ngram_ids, + is_final=data["is_final"], + ) diff --git a/bytelatent/data/iterators/__init__.py b/bytelatent/data/iterators/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/data/iterators/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py new file mode 100644 index 0000000..7fb442b --- /dev/null +++ b/bytelatent/data/iterators/abstract_iterator.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
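The Batch dataclass above carries numpy arrays, so to_python_dict()/from_python_dict() provide a JSON-safe round trip; this is what the multiprocess iterator later in this patch relies on for its serialized prefetch buffer. A small sketch of the round trip, with illustrative values only:

import json

import numpy as np

from bytelatent.data.data_types import Batch

batch = Batch(
    x=np.array([[1, 2, 3]]),
    y=np.array([[2, 3, 4]]),
    mask=np.array([[True, True, False]]),
)
payload = json.dumps(batch.to_python_dict())           # numpy arrays become nested lists
restored = Batch.from_python_dict(json.loads(payload))
assert np.array_equal(restored.x, batch.x)
assert restored.patch_lengths is None and restored.is_final is False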
+import abc +from typing import Any, Generator, Generic, TypeVar + +T = TypeVar("T") +C = TypeVar("C") + + +class StatefulIterator(Generic[T, C], abc.ABC): + + @abc.abstractmethod + def get_state(self) -> C: + pass + + @abc.abstractmethod + def create_iter(self) -> Generator[T, Any, None]: + pass + + +class IteratorState(Generic[C]): + @abc.abstractmethod + def build(self) -> StatefulIterator[T, C]: + pass diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py new file mode 100644 index 0000000..df5f023 --- /dev/null +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import re +from logging import getLogger +from pathlib import Path +from typing import Any, Generator + +import pyarrow as pa + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore +from pydantic import BaseModel, ConfigDict + +from bytelatent import ByteLatentError +from bytelatent.data.data_types import BltExample +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator + +logger = getLogger(__name__) + + +class ArrowFileIteratorState(BaseModel, IteratorState): + model_config = ConfigDict(extra="forbid") + file_path: str | None + row_num: int + num_workers: int + worker_id: int + preprocess_dir: str | None + dataset_files: list[str] | None + entropy_model_name: str | None + arrow_batch_size: int = 100 + + def build(self) -> "ArrowFileIterator": + arrow_file = ArrowFileIterator( + file_path=self.file_path, + worker_id=self.worker_id, + num_workers=self.num_workers, + preprocess_dir=self.preprocess_dir, + entropy_model_name=self.entropy_model_name, + arrow_batch_size=self.arrow_batch_size, + dataset_files=self.dataset_files, + ) + if self.row_num != 0: + arrow_file._set_row_num(self.row_num) + return arrow_file + + +def shard_sort_key(file: str | Path): + match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) + shard_number = int(match.group(1)) + return shard_number + + +class ArrowFileIterator(StatefulIterator): + def __init__( + self, + *, + file_path: str | None, + worker_id: int, + num_workers: int, + preprocess_dir: str | None, + entropy_model_name: str | None, + arrow_batch_size: int, + dataset_files: list[str] | None = None, + ): + assert 0 <= worker_id < num_workers, (worker_id, num_workers) + if file_path is None and dataset_files is None: + raise ByteLatentError("file_path and dataset_files cannot both be None") + self.row_num = 0 + self.iter_id = 0 + self.batch_iterator = None + self.batch_to_consume = None + self.dataset = None + self.file_path = file_path + self.worker_id = worker_id + self.num_workers = num_workers + self.preprocess_dir = preprocess_dir + self.entropy_model_name = entropy_model_name + self.arrow_batch_size = arrow_batch_size + if dataset_files is None: + # Prepare arrow shards + jsonl_file = Path(file_path) + parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + assert parts is not None + dataset = parts.group(1) + data_dir = Path(preprocess_dir) / dataset / entropy_model_name + shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + for s in shard_files: + if not (data_dir / f"{s.name}.complete").exists(): + raise ValueError(f"Missing .complete for input file: {s}") + + shard_files = sorted(shard_files, key=shard_sort_key) + if len(shard_files) == 0: + raise ByteLatentError( + f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and 
entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" + ) + self.dataset_files = [str(f) for f in shard_files] + else: + self.preprocess_dir = None + self.dataset_files = dataset_files + + def get_state(self) -> ArrowFileIteratorState: + return ArrowFileIteratorState( + file_path=self.file_path, + row_num=self.row_num, + worker_id=self.worker_id, + num_workers=self.num_workers, + preprocess_dir=self.preprocess_dir, + entropy_model_name=self.entropy_model_name, + arrow_batch_size=self.arrow_batch_size, + dataset_files=self.dataset_files, + ) + + def create_iter( + self, + ) -> Generator[BltExample, Any, None]: + if self.dataset is None: + self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + self.batch_iterator = self.dataset.to_batches( + batch_size=self.arrow_batch_size + ) + self.iter_id += 1 + if self.batch_to_consume is not None: + batch_columns: dict[str, list] = self.batch_to_consume + self.batch_to_consume = None + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + for i in range(len(sample_ids)): + out = BltExample( + sample_id=sample_ids[i], + entropies=entropies[i], + text=texts[i], + tokens=None, + mask=None, + patch_lengths=None, + ) + self.row_num += 1 + if (self.row_num - 1) % self.num_workers == self.worker_id: + yield out + + for batch in self.batch_iterator: + batch_columns = batch.to_pydict() + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + for i in range(len(sample_ids)): + out = BltExample( + sample_id=sample_ids[i], + entropies=entropies[i], + text=texts[i], + tokens=None, + mask=None, + patch_lengths=None, + ) + self.row_num += 1 + if (self.row_num - 1) % self.num_workers == self.worker_id: + yield out + + def _set_row_num(self, target_row_num: int): + logger.info( + f"Setting arrow position to {target_row_num} for {self.dataset_files}" + ) + if target_row_num is None or target_row_num == 0: + self.row_num = 0 + self.dataset = None + self.batch_iterator = None + self.batch_to_consume = None + else: + self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + self.batch_iterator = self.dataset.to_batches( + batch_size=self.arrow_batch_size + ) + curr_remaining = target_row_num + for batch in self.batch_iterator: + if len(batch) > curr_remaining: + batch_columns: dict[str, list] = batch.to_pydict() + batch_columns["sample_id"] = batch_columns["sample_id"][ + curr_remaining: + ] + batch_columns["entropies"] = batch_columns["entropies"][ + curr_remaining: + ] + batch_columns["text"] = batch_columns["text"][curr_remaining:] + self.batch_to_consume = batch_columns + break + elif len(batch) == curr_remaining: + # We are exactly at the end of the batch, + # so the next batch is the right spot + break + else: + curr_remaining -= len(batch) + self.row_num = target_row_num + logger.info( + f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" + ) + + +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN +): + dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of 
number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks diff --git a/bytelatent/data/iterators/looping_iterator.py b/bytelatent/data/iterators/looping_iterator.py new file mode 100644 index 0000000..2eff38c --- /dev/null +++ b/bytelatent/data/iterators/looping_iterator.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from pydantic import BaseModel + +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.arrow_iterator import ( + ArrowFileIterator, + ArrowFileIteratorState, +) + + +class LoopingIteratorState(BaseModel, IteratorState): + file_iterator_state: ArrowFileIteratorState + epoch: int + + def build(self) -> "LoopingIterator": + return LoopingIterator( + file_iterator=self.file_iterator_state.build(), + epoch=self.epoch, + ) + + +class LoopingIterator(StatefulIterator): + def __init__(self, file_iterator: ArrowFileIterator, epoch: int = -1): + self.file_iterator = file_iterator + self.epoch = epoch + + def get_state(self): + return LoopingIteratorState( + file_iterator_state=self.file_iterator.get_state(), epoch=self.epoch + ) + + def create_iter(self): + while True: + self.epoch += 1 + iterator = self.file_iterator.create_iter() + yield from iterator diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py new file mode 100644 index 0000000..f17ca6e --- /dev/null +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import json +import logging +import multiprocessing as mp +from multiprocessing.synchronize import Event as EventClass +from queue import Empty, Full + +import numpy as np +from pydantic import BaseModel, ConfigDict + +from bytelatent.data.data_types import Batch +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.packing_iterator import PackingIteratorState + +logger = logging.getLogger() + + +class MultiprocessIteratorState(BaseModel, IteratorState): + model_config = ConfigDict(extra="forbid") + base_iterator_state: PackingIteratorState + n_batches_to_prefetch: int + serialized_prefetch_buffer: str + + def build(self): + base_iterator = self.base_iterator_state.build() + data = json.loads(self.serialized_prefetch_buffer) + prefetch_buffer = [Batch.from_python_dict(item) for item in data] + return MultiprocessIterator( + base_iterator, + n_batches_to_prefetch=self.n_batches_to_prefetch, + prefetch_buffer=prefetch_buffer, + ) + + +def start_work_from_state( + batch_queue: mp.Queue, + state_queue: mp.Queue, + stop_event: EventClass, + state_dumped_event: EventClass, + state: IteratorState, +): + logging.info("Worker thread: Starting base_iterator work") + stateful_iterator = state.build() + iterator = stateful_iterator.create_iter() + for item in iterator: + while not stop_event.is_set(): + try: + # Attempt to put on queue or timeout to try again (maybe main thread is busy) + batch_queue.put(item, timeout=0.1) + # On success, stop trying + break + except Full: + pass + if stop_event.is_set(): + # Signal the end of output, this ensures that even if the queue takes a while to + # buffer, that the main thread receives everything (and tosses this fake batch) + logging.info( + "Worker thread: Stop event detected, outputting is_final=True batch" + ) + batch_queue.put( + Batch( + x=np.zeros((1, 1)), + y=np.zeros((1, 1)), + is_final=True, 
+ mask=None, + patch_lengths=None, + ngram_ids=None, + ) + ) + break + + try: + logging.info("Worker thread: outputting state") + state_queue.put(iterator.get_state(), timeout=1) + logging.info("Worker thread: state dump complete") + state_dumped_event.set() + logging.info("Worker thread: set state_dump_event") + except Full: + raise ValueError( + "Attempted to dump state into the state queue, but it was full" + ) + + +class MultiprocessIterator(StatefulIterator): + """ + Design sketch of the multiprocess iterator: + + Given the base_iterator, the only thing we do with this is call get_state() + so that we can pass that through to the background worker process. + + The background process will receive this, rebuild the iterator, then start yielding from it. + + However, in order to implement MultiprocessIterator.get_state(), we need to be able to accurately get + (1) the state of the iterator in the worker process + (2) the currently buffered items in the Queue + + To do this, we use: + - batch_queue: This is the prefetch buffer the worker yields to and the main loop yields from + - state_queue: This size 1 queue will be how the worker sends the iterator state once it has halted iterating. + It must hold the state in addition to the last batch, if the queue was full at the time the stop event is sent. + - stop_iterating_event: Once this is issued from the main loop, the worker will stop iterating and enter cleanup. + During cleanup, the iterator will send the state of the current iterator to the main loop, + in addition to possibly the last batch if the batch_queue was full at the time + - state_dumped_event: When the main loop issues the stop_iterating_event, it will wait until the state_dumped_event to attempt + to get state from the state_queue. It must do this since the worker may take some time to create and send the state. + Once received by the main loop, the main loop can safely store the Queue (plus maybe the last batch) as the prefetch buffer, + get the worker iterator's state, and terminate the background process + delete associated objects. + + At this point, calling create_iter() again will bootstrap everything from the stored state and the old iterator will throw an error + since it will not iterate anymore (so the caller must call create_iter() again to get a python iterator). 
+ + """ + + def __init__( + self, + base_iterator: StatefulIterator, + *, + n_batches_to_prefetch: int, + prefetch_buffer: list | None = None + ): + self.base_iterator = base_iterator + self.n_batches_to_prefetch = n_batches_to_prefetch + if prefetch_buffer is None: + prefetch_buffer = [] + self.prefetch_buffer = prefetch_buffer + self.batch_queue = None + self.state_queue = None + self.producer = None + self.stop_iterating_event = None + self.state_dumped_event = None + + def get_state(self) -> MultiprocessIteratorState: + """ + This is slightly unusual in effectively destroying the current iterator, its necessary + to halt the background process and allow it to write the state to the main loop + in order to not lose data + """ + if self.producer is None: + serialized_prefetch_buffer = json.dumps( + [b.to_python_dict() for b in self.prefetch_buffer] + ) + return MultiprocessIteratorState( + base_iterator_state=self.base_iterator.get_state(), + n_batches_to_prefetch=self.n_batches_to_prefetch, + serialized_prefetch_buffer=serialized_prefetch_buffer, + ) + else: + logging.info("Main thread: Sending stop iteration event") + self.stop_iterating_event.set() + logging.info("Main thread: Waiting for state_dumped event") + self.state_dumped_event.wait() + self.prefetch_buffer = [] + final_batch_received = False + while True: + try: + batch = self.batch_queue.get(timeout=1) + if batch.is_final: + final_batch_received = True + break + self.prefetch_buffer.append(batch) + except Empty: + logging.warning("Main thread: batch_queue is abnormally empty") + assert final_batch_received + + try: + base_iterator_state = self.state_queue.get(timeout=1) + assert isinstance(base_iterator_state, IteratorState) + except Empty: + raise ValueError( + "Attempted to get the state, but it was unexpectantly missing" + ) + + self.base_iterator = base_iterator_state.build() + self.producer.close() + self.producer = None + self.batch_queue = None + self.state_queue = None + self.stop_iterating_event = None + self.state_dumped_event = None + + return MultiprocessIteratorState( + base_iterator_state=self.base_iterator.get_state(), + n_batches_to_prefetch=self.n_batches_to_prefetch, + serialized_prefetch_buffer=json.dumps( + [b.to_python_dict() for b in self.prefetch_buffer] + ), + ) + + def create_iter(self): + logging.info("Main thread: Creating MP iterator") + # First yield from the stored prefetch buffer. + if self.prefetch_buffer is not None: + while len(self.prefetch_buffer) > 0: + item = self.prefetch_buffer.pop(0) + yield item + self.prefetch_buffer = None + + assert ( + self.producer is None + ), "Cannot create two parallel iterators at once, call get_state() then remake to have two." 
+ + # using mp context manager avoids excessive CPU loading + ctx = mp.get_context("forkserver") + self.batch_queue = ctx.Manager().Queue(maxsize=self.n_batches_to_prefetch) + + # We should only ever have one state, which is output at the detection of a stop event + self.state_queue = ctx.Manager().Queue(maxsize=1) + + self.stop_iterating_event = ctx.Event() + self.state_dumped_event = ctx.Event() + + self.producer = mp.Process( + name="blt_data_loader", + target=start_work_from_state, + args=( + self.batch_queue, + self.state_queue, + self.stop_iterating_event, + self.state_dumped_event, + self.base_iterator.get_state(), + ), + ) + logger.info("Async dataloader started") + self.producer.start() + + while True: + if self.producer.exitcode is not None: + raise RuntimeError( + "Data loader quit unexpectedly, real error has been raised previously" + ) + try: + batch = self.batch_queue.get(timeout=0.1) + assert isinstance(batch, Batch) + assert ( + not batch.is_final + ), "is_final should only be used during get_state() being called" + yield batch + except Empty: + pass + if self.producer is None: + raise ValueError( + "Attempted to call this iterator after calling get_state(). You must call create_iter() to make a new iterator instead." + ) diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py new file mode 100644 index 0000000..361fc03 --- /dev/null +++ b/bytelatent/data/iterators/packing_iterator.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import Any + +import numpy as np +from pydantic import BaseModel, ConfigDict + +from bytelatent.data.data_types import Batch, BltSequence +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState + + +class PackingArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + batch_size: int + seq_len: int + pad_id: int + max_length: int | None + pad_to_max_length: bool + enable_byte_ngrams: bool + + +class PackingIteratorState(BaseModel, IteratorState): + model_config = ConfigDict(extra="forbid") + sequence_iterator_state: SamplingIteratorState + packing_args: PackingArgs + + def build(self) -> "PackingIterator": + return PackingIterator( + sequence_iterator=self.sequence_iterator_state.build(), + packing_args=self.packing_args, + ) + + +def _merge_patch_seq_masks(bs, slen: int, mask_seqs: list[list[bool]]): + assert len(mask_seqs) == bs + lens = [len(m) for m in mask_seqs] + if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens): + return None + assert slen == max(lens) - 1 + mask = np.zeros((bs, slen), dtype=bool) + for i, m in enumerate(mask_seqs): + if m is None: + print( + "Did not implement None mask, the mask should be True for all toks, so we need to pass that to this function." + ) + raise NotImplementedError + mask[i][: len(mask_seqs[i]) - 1] = mask_seqs[i][1:] + return mask + + +def truncate_batch( + batch: Batch, + max_length: int, + pad_id: int, + pad_to_max_length: bool = False, + *, + enable_byte_ngrams: bool, +): + """ + Truncate x to a given size, making sure we remove the corresponding patch sizes in patch_lengths + and fixing the batch.mask.
+ + batch.patch_lengths has unchanged shape + x, y, and mask may reduce in size + """ + if batch.patch_lengths is None: + return batch + + seq_lengths = batch.patch_lengths.sum(axis=1) + max_length_adj = max_length + 1 + if np.any(seq_lengths > max_length_adj): + for i in range(batch.x.shape[0]): + if seq_lengths[i] > max_length_adj: + # Find id of patch that tips over max_length + 1 + count, j = 0, 0 + while count + batch.patch_lengths[i, j] <= max_length_adj: + count += batch.patch_lengths[i, j] + j += 1 + # Edit the batch + assert j < batch.patch_lengths.shape[1] + batch.x[i, max_length:] = pad_id + batch.y[i, max_length:] = pad_id + if batch.mask is not None: + batch.mask[i, max_length:] = False + batch.patch_lengths[i, j:] = 0 + batch.patch_lengths[i, j] = max_length_adj - count + + # Truncate if necessary. + if max_length < batch.x.shape[1]: + batch.x = batch.x[:, :max_length] + batch.y = batch.y[:, :max_length] + if batch.mask is not None: + batch.mask = batch.mask[:, :max_length] + + # Right pad to max_length if necessary + elif pad_to_max_length: + if batch.x.shape[1] < max_length: + # NOTE: this has to be done on an actual patch. + non_zero_indices = (batch.patch_lengths != 0).sum(axis=1) - 1 + non_zero_indices = np.maximum(0, non_zero_indices) + batch.patch_lengths[range(len(batch.patch_lengths)), non_zero_indices] += ( + max_length - batch.x.shape[1] + ) + # TODO: We could get rid of many of these complications by moving this function directly in the dataloader. + x = np.full((batch.x.shape[0], max_length), pad_id, dtype=batch.x.dtype) + x[:, : batch.x.shape[1]] = batch.x + batch.x = x + if batch.y.shape[1] < max_length: + y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype) + y[:, : batch.y.shape[1]] = batch.y + batch.y = y + if batch.mask is not None and batch.mask.shape[1] < max_length: + mask = np.full( + (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype + ) + mask[:, : batch.mask.shape[1]] = batch.mask + batch.mask = mask + + assert batch.x.shape[1] <= max_length + assert batch.y.shape[1] <= max_length + assert batch.mask is None or batch.mask.shape[1] <= max_length + assert np.all(max_length_adj - batch.patch_lengths.sum(axis=1) == 0) + if pad_to_max_length: + assert batch.x.shape[1] == max_length + assert batch.y.shape[1] == max_length + assert batch.mask is None or batch.mask.shape[1] == max_length + if enable_byte_ngrams: + raise NotImplementedError() + # (num_ngram, batch_size, seq_len) + ngram_ids = np.array(tokenizer.encode_token_ngrams(batch.x)) + assert ngram_ids.shape[2] == batch.x.shape[1] + else: + ngram_ids = None + batch.ngram_ids = ngram_ids + + +class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): + def __init__( + self, + sequence_iterator: StatefulIterator[BltSequence, Any], + *, + packing_args: PackingArgs, + ): + self.sequence_iterator = sequence_iterator + self.packing_args = packing_args + + def get_state(self): + return PackingIteratorState( + sequence_iterator_state=self.sequence_iterator.get_state(), + packing_args=self.packing_args, + ) + + def create_iter(self): + sequence_iter = self.sequence_iterator.create_iter() + batch_size = self.packing_args.batch_size + pad_id = self.packing_args.pad_id + seq_len = self.packing_args.seq_len + pad_to_max_length = self.packing_args.pad_to_max_length + enable_byte_ngrams = self.packing_args.enable_byte_ngrams + max_length = self.packing_args.max_length + while True: + tokens: list[list[int]] = [] + masks: list[list[bool]] = [] + patch_lengths: 
list[list[int]] = [] + + for _ in range(self.packing_args.batch_size): + sequence = next(sequence_iter) + _tokens = sequence.tokens + _mask = sequence.mask + _patch_lengths = sequence.patch_lengths + assert len(sequence.patch_lengths) == self.packing_args.seq_len + last_patch_length = 0 + if _patch_lengths[0] > 1: + last_patch_length = _patch_lengths[-1] + _patch_lengths[0] -= 1 + _patch_lengths = [1] + _patch_lengths[:-1] + tokens.append(_tokens[: len(_tokens) - last_patch_length]) + masks.append(_mask[: len(_mask) - last_patch_length]) + patch_lengths.append(_patch_lengths) + + x_patch_lengths = np.array(patch_lengths) + # pad batch to same length + tok_seq_len = max([len(toks) for toks in tokens]) - 1 + x = np.full((batch_size, tok_seq_len), fill_value=pad_id) + y = np.full((batch_size, tok_seq_len), fill_value=pad_id) + + for i, tok_seq in enumerate(tokens): + x[i, : len(tok_seq) - 1] = tok_seq[:-1] + y[i, : len(tok_seq) - 1] = tok_seq[1:] + # Adjust patch lengths to match x + x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1) + + assert x_patch_lengths.shape == (batch_size, seq_len) + + if enable_byte_ngrams: + raise NotImplementedError() + else: + ngram_ids = None + + batch = Batch( + x=x, + y=y, + patch_lengths=x_patch_lengths, + ngram_ids=ngram_ids, + mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks), + ) + assert ( + x_patch_lengths.sum() == x.size + batch_size + ), f"{x_patch_lengths.sum()} != {x.size + batch_size}" + assert ( + batch.mask is None or np.sum(x != pad_id) == batch.mask.sum() + ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}" + assert np.all( + x_patch_lengths[:, 0] == 1 + ), f"first patch should always be 1, {x_patch_lengths[:, 0]}" + # cuda_gb_allocated = (torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024) + # cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024 + # print(f"dataloader cuda_gb_allocated: {cuda_gb_allocated}, cuda_gb_reserved: {cuda_gb_reserved}") + truncate_batch( + batch, + max_length=max_length, + pad_id=pad_id, + pad_to_max_length=pad_to_max_length, + enable_byte_ngrams=enable_byte_ngrams, + ) + yield batch diff --git a/bytelatent/data/iterators/preprocess_iterator.py b/bytelatent/data/iterators/preprocess_iterator.py new file mode 100644 index 0000000..8eeba41 --- /dev/null +++ b/bytelatent/data/iterators/preprocess_iterator.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
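The packing logic above enforces two invariants that are easy to miss: every row's first patch is forced to length 1, and after shifting tokens into x/y the per-row patch lengths sum to the token row length plus one, which is exactly the assert x_patch_lengths.sum() == x.size + batch_size. A worked single-row sketch with made-up numbers, not taken from the patch:

import numpy as np

tokens = list(range(10))         # 10 byte tokens for one sequence, BOS first
patch_lengths = [3, 2, 2, 3]     # hypothetical seq_len = 4 patches, summing to len(tokens)

# First patch is longer than 1, so split off a length-1 patch for BOS, drop the
# last patch, and truncate its tokens, as PackingIterator.create_iter does above.
last_patch_length = patch_lengths[-1]
patch_lengths[0] -= 1
patch_lengths = [1] + patch_lengths[:-1]            # [1, 2, 2, 2] -> sums to 7
tokens = tokens[: len(tokens) - last_patch_length]  # 7 tokens remain

tok_seq_len = len(tokens) - 1                       # x and y hold 6 positions each
x, y = np.array(tokens[:-1]), np.array(tokens[1:])

# Per row, sum(patch_lengths) == tok_seq_len + 1; summed over a batch this is the
# assert x_patch_lengths.sum() == x.size + batch_size checked above.
assert sum(patch_lengths) == tok_seq_len + 1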
+from typing import Any, Generator + +import torch +from pydantic import BaseModel, ConfigDict + +from bytelatent.data.data_types import BltExample +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.arrow_iterator import ( + ArrowFileIterator, + ArrowFileIteratorState, +) +from bytelatent.data.iterators.looping_iterator import LoopingIteratorState +from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum +from bytelatent.tokenizers.blt_tokenizer import BltTokenizer +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +class PreprocessIteratorState(BaseModel, IteratorState): + model_config = ConfigDict(extra="forbid") + arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState + add_tokens: bool + add_patches: bool + tokenizer_args: TokenizerArgs + patcher_args: PatcherArgs + + def build(self): + arrow_iterator = self.arrow_file_iterator_state.build() + return PreprocessIterator( + arrow_iterator, + patcher_args=self.patcher_args, + tokenizer_args=self.tokenizer_args, + add_tokens=self.add_tokens, + add_patches=self.add_patches, + ) + + +class PreprocessIterator(StatefulIterator): + """ + Take BltExamples with fields filled in only from ArrowFileIterator, and fill in fields that require + preprocessing like tokenization and patching + """ + + def __init__( + self, + arrow_iterator: ArrowFileIterator, + *, + patcher_args: PatcherArgs, + tokenizer_args: TokenizerArgs, + add_tokens: bool = True, + add_patches: bool = True, + ): + self.arrow_iterator = arrow_iterator + self.tokenizer_args = tokenizer_args + self.patcher_args = patcher_args + self.add_tokens = add_tokens + self.add_patches = add_patches + self.tokenizer: BltTokenizer | None = None + self.patcher: Patcher | None = None + + def get_state(self) -> PreprocessIteratorState: + """ + The only state to maintain here is from arrow, there + isn't any internal state on this iterator. 
+ """ + return PreprocessIteratorState( + arrow_file_iterator_state=self.arrow_iterator.get_state(), + tokenizer_args=self.tokenizer_args, + patcher_args=self.patcher_args, + add_tokens=self.add_tokens, + add_patches=self.add_patches, + ) + + def create_iter(self) -> Generator[BltExample, Any, None]: + if self.tokenizer is None and self.add_tokens: + self.tokenizer = self.tokenizer_args.build() + if self.patcher is None and self.add_patches: + self.patcher = self.patcher_args.build() + + example_iter = self.arrow_iterator.create_iter() + for example in example_iter: + if self.add_tokens: + tokens = self.tokenizer.encode(example.text) + else: + tokens = example.tokens + if ( + self.patcher is not None + and self.patcher.patching_mode == PatchingModeEnum.entropy + ): + assert ( + example.entropies is not None + ), "For patching, entropies cannot be None" + entropies = torch.tensor(example.entropies).unsqueeze(0) + else: + entropies = None + if self.patcher is None: + patch_lengths = None + else: + patch_lengths = self.patcher.patch( + torch.tensor(tokens).unsqueeze(0), + include_next_token=False, + entropies=entropies, + )[0][0].tolist() + yield BltExample( + sample_id=example.sample_id, + text=example.text, + tokens=tokens, + mask=[True] * len(tokens), + patch_lengths=patch_lengths, + entropies=example.entropies, + ) diff --git a/bytelatent/data/iterators/sampling_iterator.py b/bytelatent/data/iterators/sampling_iterator.py new file mode 100644 index 0000000..6474bf6 --- /dev/null +++ b/bytelatent/data/iterators/sampling_iterator.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import Any + +import numpy as np +from pydantic import BaseModel, ConfigDict + +from bytelatent.data.iterators.abstract_iterator import StatefulIterator +from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState + + +class SamplingIteratorState(BaseModel): + model_config = ConfigDict(extra="forbid") + rng_state: dict[str, Any] + source_to_weight: dict[str, float] + source_to_iterator_state: dict[str, SequenceIteratorState] + + def build(self) -> "SamplingIterator": + return SamplingIterator( + rng_state=self.rng_state, + source_to_weight=self.source_to_weight, + source_to_iterator={ + source: state.build() + for source, state in self.source_to_iterator_state.items() + }, + ) + + +class SamplingIterator(StatefulIterator): + def __init__( + self, + *, + rng_state: dict[str, Any], + source_to_weight: dict[str, float], + source_to_iterator: dict[str, StatefulIterator], + ): + self.rng = np.random.default_rng() + self.rng.bit_generator.state = rng_state + self.source_to_weight = source_to_weight + self.source_to_iterator = source_to_iterator + + def get_state(self) -> SamplingIteratorState: + return SamplingIteratorState( + rng_state=self.rng.bit_generator.state, + source_to_weight=self.source_to_weight, + source_to_iterator_state={ + source: iterator.get_state() + for source, iterator in self.source_to_iterator.items() + }, + ) + + def create_iter(self): + n_sources = len(self.source_to_weight) + possible_sources = [] + weights = [] + for source, w in self.source_to_weight.items(): + possible_sources.append(source) + weights.append(w) + + source_to_python_iter = { + source: self.source_to_iterator[source].create_iter() + for source in possible_sources + } + while True: + norm_weights = np.array(weights) / np.array(weights).sum() + source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)] + yield next(source_to_python_iter[source_choice]) diff 
--git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py new file mode 100644 index 0000000..14e3747 --- /dev/null +++ b/bytelatent/data/iterators/sequence_iterator.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from logging import getLogger +from typing import Any + +import numpy as np +from pydantic import BaseModel, ConfigDict + +from bytelatent.data.data_types import BltSequence +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.preprocess_iterator import ( + PreprocessIterator, + PreprocessIteratorState, +) + +logger = getLogger() + + +class SequencePackingArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + output_seq_len: int + buffer_size: int + + +class SequenceIteratorState(BaseModel, IteratorState): + model_config = ConfigDict(extra="forbid") + sequence_packing_args: SequencePackingArgs + preprocess_iterator_state: PreprocessIteratorState + rng_state: dict[str, Any] + + def build(self): + preprocess_iterator = self.preprocess_iterator_state.build() + return SequenceIterator( + preprocess_iterator, + sequence_packing_args=self.sequence_packing_args, + rng_state=self.rng_state, + ) + + +class SequenceIterator(StatefulIterator): + def __init__( + self, + preprocess_iterator: PreprocessIterator, + *, + rng_state: dict[str, Any], + sequence_packing_args: SequencePackingArgs, + ): + self.preprocess_iterator = preprocess_iterator + self.sequence_packing_args = sequence_packing_args + self.output_seq_len = sequence_packing_args.output_seq_len + self.buffer_size = sequence_packing_args.buffer_size + self.rng = np.random.default_rng() + self.rng.bit_generator.state = rng_state + + def get_state(self): + # TODO: need to also persist the current shuffle buffer + return SequenceIteratorState( + sequence_packing_args=self.sequence_packing_args, + preprocess_iterator_state=self.preprocess_iterator.get_state(), + rng_state=self.rng.bit_generator.state, + ) + + def create_iter(self): + example_iter = self.preprocess_iterator.create_iter() + n_buffer_patches = self.buffer_size * self.output_seq_len + + patch_lengths: list[int] = [] + tokens: list[int] = [] + mask: list[bool] = [] + first = True + for example in example_iter: + assert example.tokens is not None + assert example.mask is not None + assert example.patch_lengths is not None + assert len(example.tokens) != 0 + assert len(example.mask) != 0 + assert len(example.tokens) == len(example.mask) + assert len(example.tokens) == sum(example.patch_lengths) + + tokens.extend(example.tokens) + mask.extend(example.mask) + patch_lengths.extend(example.patch_lengths) + + while len(patch_lengths) >= n_buffer_patches: + if first: + first = False + logger.info("First buffer complete") + + x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape( + self.buffer_size, self.output_seq_len + ) + seq_tokens = [] + seq_mask = [] + start_id = 0 + # We fix the number of patches and therefore global steps per batch + # so we have a variable number of tokens we need to account for + for num_tokens in x_patches.sum(axis=-1): + seq_tokens.append(tokens[start_id : start_id + num_tokens]) + seq_mask.append(mask[start_id : start_id + num_tokens]) + start_id += num_tokens + + assert start_id == x_patches.sum() + + # Remove what we just added from the buffer + patch_lengths = patch_lengths[n_buffer_patches:] + tokens = tokens[x_patches.sum() :] + mask = mask[x_patches.sum() :] + + seq_patch_lengths: 
list[list[int]] = x_patches.tolist() + assert len(seq_patch_lengths) == self.buffer_size + for idx in self.rng.permutation(len(seq_patch_lengths)): + assert len(seq_patch_lengths[idx]) == self.output_seq_len + assert ( + sum(seq_patch_lengths[idx]) + == len(seq_tokens[idx]) + == len(seq_mask[idx]) + ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}" + assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}" + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=seq_patch_lengths[idx], + ) diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py new file mode 100644 index 0000000..4266427 --- /dev/null +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import numpy as np +import pyarrow as pa + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + +from bytelatent.constants import BLT_DATA +from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState + +ENTROPY_MODEL = "transformer_100m" +ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow") +ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow") + + +def test_basic_arrow_file(): + dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow") + n_head = 1000 + head_df = dataset.head(n_head).to_pandas() + + initial_state = ArrowFileIteratorState( + file_path=None, + num_workers=1, + worker_id=0, + preprocess_dir=None, + entropy_model_name=ENTROPY_MODEL, + dataset_files=[ARROW_TEST_DATA_1], + row_num=0, + arrow_batch_size=100, + ) + arrow_file = initial_state.build() + start_state = arrow_file.get_state() + assert start_state.row_num == initial_state.row_num + + sample_id = None + for example in arrow_file.create_iter(): + sample_id = example.sample_id + assert head_df.iloc[0]["sample_id"] == sample_id + break + + assert arrow_file.get_state().row_num == 1 + arrow_file = initial_state.build() + for example in arrow_file.create_iter(): + assert example.sample_id == sample_id + assert head_df.iloc[0]["sample_id"] == sample_id + break + + # Test resume far enough in to be past the batch size of 100 + resumed_state = ArrowFileIteratorState( + file_path=None, + num_workers=1, + worker_id=0, + preprocess_dir=None, + entropy_model_name=ENTROPY_MODEL, + dataset_files=[ARROW_TEST_DATA_1], + row_num=251, + arrow_batch_size=100, + ) + arrow_file = resumed_state.build() + for example in arrow_file.create_iter(): + assert example.sample_id == head_df.iloc[251]["sample_id"] + assert arrow_file.get_state().row_num == 252 + break + + world_rank = 1 + world_size = 4 + # Test World Size and Rank + rank_state = ArrowFileIteratorState( + file_path=None, + num_workers=world_size, + worker_id=world_rank, + preprocess_dir=None, + entropy_model_name=ENTROPY_MODEL, + dataset_files=[ARROW_TEST_DATA_1], + row_num=0, + arrow_batch_size=100, + ) + arrow_file = rank_state.build() + expected_ids = [] + for i in range(n_head): + if i % world_size == world_rank: + expected_ids.append(head_df.iloc[i]["sample_id"]) + print(len(expected_ids)) + i = 0 + for example in arrow_file.create_iter(): + assert example.sample_id == expected_ids[i] + i += 1 + if i >= len(expected_ids): + break diff --git a/bytelatent/data/iterators/test_iters.py b/bytelatent/data/iterators/test_iters.py new file mode 100644 index 0000000..9bc9d59 --- /dev/null +++ 
b/bytelatent/data/iterators/test_iters.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import pandas as pd +from pydantic import BaseModel + +from bytelatent.constants import BLT_DATA +from bytelatent.data.data_types import BltExample +from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator +from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +class BltTestIteratorState(BaseModel, IteratorState): + position: int + total: int + + def build(self): + blt_iter = BltTestIteratorState(total=self.total) + blt_iter.position = self.position + return blt_iter + + +class BltTestIterator(StatefulIterator): + def __init__(self, total: int): + self.position = 0 + self.total = total + + def get_state(self): + return BltTestIteratorState(position=self.position, total=self.total) + + def create_iter(self): + for i in range(self.total): + self.position += 1 + yield BltExample( + sample_id=f"test_{i}", + text=f"This is some test {i} text.", + tokens=None, + mask=None, + entropies=None, + patch_lengths=None, + ) + + +class BltTestWithEntropiesIteratorState(BaseModel, IteratorState): + position: int + total: int + + def build(self): + blt_iter = BltTestWithEntropiesIteratorState(total=self.total) + blt_iter.position = self.position + return blt_iter + + +class BltTestWithEntropiesIterator(StatefulIterator): + def __init__(self, total: int): + self.position = 0 + self.total = total + + def get_state(self): + return BltTestIteratorState(position=self.position, total=self.total) + + def create_iter(self): + text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin." 
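# Descriptive note on the fixture read below (an assumption about its layout, not stated in the
# patch itself): fixtures/tokens_with_entropies.json is expected to hold one row per byte token
# of the sentence above, with "token_ids" and "entropies" columns; the later
# len(tokens) == len(text) + 2 assertion accounts for the BOS and EOS tokens.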
+ df = pd.read_json("fixtures/tokens_with_entropies.json") + tokens = df["token_ids"].tolist() + entropies = df["entropies"].tolist() + # BOS and EOS + assert len(tokens) == len(text) + 2 + for i in range(self.total): + self.position += 1 + yield BltExample( + sample_id=f"test_{i}", + text=text, + tokens=tokens, + mask=[True] * len(tokens), + entropies=entropies, + patch_lengths=None, + ) + + +def test_preprocess_iter(): + total = 3 + tokenizer_args = TokenizerArgs( + name="blt", + init_kwargs={ + "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" + }, + ) + for mode in [ + PatchingModeEnum.bpe, + PatchingModeEnum.space, + ]: + data_it = BltTestIterator(total) + patcher_args = PatcherArgs(patching_mode=mode) + example_it = PreprocessIterator( + data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args + ) + count = 0 + for example in example_it.create_iter(): + assert isinstance(example.tokens, list) + assert isinstance(example.tokens[0], int) + # BOS and EOS + assert len(example.tokens) == len(example.text) + 2 + assert example.mask is not None + assert len(example.tokens) == len(example.mask) + count += 1 + + assert count == total + + +def test_non_entropy_patch_iter(): + total = 3 + tokenizer_args = TokenizerArgs( + name="blt", + init_kwargs={ + "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" + }, + ) + for mode in [ + PatchingModeEnum.bpe, + PatchingModeEnum.space, + ]: + patcher_args = PatcherArgs(patching_mode=mode) + data_it = BltTestIterator(total) + example_it = PreprocessIterator( + data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args + ) + + count = 0 + for example in example_it.create_iter(): + assert isinstance(example.patch_lengths, list) + assert isinstance(example.patch_lengths[0], int) + assert len(example.tokens) == sum(example.patch_lengths) + count += 1 + + assert count == total + + +def test_entropy_patch_iter(): + total = 2 + patcher_args = PatcherArgs( + patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627 + ) + tokenizer_args = TokenizerArgs( + name="blt", + init_kwargs={ + "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" + }, + ) + data_it = BltTestWithEntropiesIterator(total) + example_it = PreprocessIterator( + data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args + ) + + count = 0 + for example in example_it.create_iter(): + assert isinstance(example.patch_lengths, list) + assert isinstance(example.patch_lengths[0], int) + assert len(example.tokens) == sum(example.patch_lengths) + count += 1 + + assert count == total diff --git a/bytelatent/data/ngram_processor.py b/bytelatent/data/ngram_processor.py new file mode 100644 index 0000000..2498183 --- /dev/null +++ b/bytelatent/data/ngram_processor.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import pickle +from pathlib import Path + +import numpy as np + +from bytelatent import ByteLatentError + +LOOKUP_OFFSET = 4 + + +def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1): + """ + Wrapper function for applying the lookup table to each n-gram. + + :param ngram: Array of numbers representing an n-gram. + :param lookup_table: Dictionary where keys are tuples (n-grams) and values are the desired outputs. + :param lookup_offset: Offset to add to the lookup result. + :return: The value associated with the n-gram tuple in the dictionary, or None if not found. 
+ """ + + def apply_lookup_table(ngram): + """ + Function to apply to each n-gram: converts it to a tuple and looks it up in a dictionary. + + :param ngram: Array of numbers representing an n-gram. + :return: The value associated with the n-gram tuple in the dictionary, or None if not found. + """ + # Convert the n-gram to a tuple + ngram_tuple = tuple(ngram) + + if ngram_tuple not in ngram_to_idx: + return 0 + else: + return ngram_to_idx[ngram_tuple] + lookup_offset + + return apply_lookup_table + + +def get_byte_ngrams_ids( + byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0 +): + """ + Generate n-grams from a 2D numpy array. + + :param n: The length of each n-gram. + :param pad_value: The value used for padding of the byte values to maintain the same dimensions for the n-grams. + :return: A 2D numpy array where each element is the ID of an n-gram offset by LOOKUP_OFFSET. + """ + num_rows, num_cols = byte_array.shape + + # Create an array to hold the padded version of the original array + padded_array = np.pad( + byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value + ) + + # Use stride tricks to avoid explicit looping + strided = np.lib.stride_tricks.as_strided + shape = (num_rows, num_cols, n) + strides = padded_array.strides[:2] + (padded_array.strides[1],) + ngrams = strided(padded_array, shape=shape, strides=strides) + + ngram_ids = np.apply_along_axis( + apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET), 2, ngrams + ) + assert ngram_ids.shape == byte_array.shape + return ngram_ids + + +def reload_tables( + ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET +) -> tuple[dict[int, list], dict[tuple, int], dict[int, int]]: + """ + Reload lookup tables from a directory. Reload only the ngrams in the dictionary and per ngram, + only load up to the max specified size. Return the actual number of ngrams taken per ngram size. 
+ """ + idx_to_ngram_tables = {} + ngram_to_idx_tables = {} + vocab_sizes = {} + for ngram, size in ngram_to_size.items(): + with open(Path(ngram_table_dir) / f"ngram-{ngram}.pickle", "rb") as f: + # These are already sorted by count + # Value: tuple of: count, ngram, dataset + ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[ + "counts" + ] + table = [ngram for ngram, _ in ngram_data][:size] + if len(table) != size: + raise ValueError( + f"Ngram table for {ngram}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}" + ) + ngram_to_idx = {ngram: idx for idx, ngram in enumerate(table)} + actual_size = len(table) + idx_to_ngram_tables[ngram] = table + ngram_to_idx_tables[ngram] = ngram_to_idx + vocab_sizes[ngram] = actual_size + offset + return ngram_to_idx_tables, ngram_to_idx_tables, vocab_sizes + + +def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]: + if ngram_to_size_str is None: + return None + ngram_to_size = {} + for entry in ngram_to_size_str.split(","): + ngram, size = entry.split(":") + ngram = int(ngram) + size = int(size) + ngram_to_size[ngram] = size + return ngram_to_size + + +class NgramProcessor: + def __init__( + self, + ngram_table_dir: str | None = None, + ngram_to_size: dict[int, int] | None = None, + ): + if ngram_table_dir is None or ngram_to_size is None: + raise ByteLatentError( + "ngram_table_dir and ngram_to_size cannot be none if enable_byte_ngrams is True" + ) + ( + self.ngram_to_idx_tables, + self.idx_to_ngram_tables, + self.ngram_vocab_sizes, + ) = reload_tables(ngram_table_dir, ngram_to_size) + # Lowest to highest ngram + self.ngram_sizes = sorted(list(self.ngram_to_idx_tables.keys())) + # Although the model might not use all the ngrams, we need the tokenizer + # to produce ngram_ids such that index zero is the 2-gram, later on in + # src.model.megabyte.Megabyte.forward + assert self.ngram_sizes[0] == 2 + + def encode_single_ngram_table(self, data: np.ndarray, n: int): + """ + Return the n-grams of the input data for a given n + numpy array with ids of shape data.shape + """ + return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0) + + def encode_token_ngrams(self, data: np.ndarray): + """ + Return the n-grams of the input data. + output shape: [ids with data.shape for n in self.ngram_sizes] + """ + return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes] diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py new file mode 100644 index 0000000..ede8b06 --- /dev/null +++ b/bytelatent/data/patcher.py @@ -0,0 +1,609 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+import math +import time +from collections import defaultdict +from enum import Enum + +import torch +from pydantic import BaseModel +from torch.nn import functional as F + +from bytelatent.distributed import get_local_rank +from bytelatent.entropy_model import load_entropy_model + +# from src.slurm import get_local_rank +from bytelatent.tokenizers.blt_tokenizer import BPE_ID, OFFSET +from bytelatent.tokenizers.constants import BPE_ID, OFFSET + + +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + + +class PatcherArgs(BaseModel): + patching_mode: PatchingModeEnum = PatchingModeEnum.entropy + patching_device: str = "cuda" + entropy_model_checkpoint_dir: str | None = None + realtime_patching: bool = False + threshold: float = 1.335442066192627 + threshold_add: float | None = None + max_patch_length: int | None = None + patch_size: float = 4.5 + patching_batch_size: int = 1 + data_loader_patching: bool = False + device: str = "cuda" + monotonicity: bool = False + log_time: bool = False + + def build(self) -> "Patcher": + return Patcher(self) + + +def entropy(scores): + """ + scores: [bs, seq_len, vocab] + returns [bs, seq_len] + + Computes the entropy for each token in the batch. + Note: uses natural log. + """ + log_probs = F.log_softmax(scores, dim=-1) + probs = torch.exp(log_probs) + p_log_p = log_probs * probs + entropy = -p_log_p.sum(dim=-1) + return entropy + + +def calculate_entropies( + tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None +): + """ + tokens: 2D tensor of shape [batch_size, seq_len] + Return 2D tensor of shape [batch_size, seq_len] with entropies for each token. + + Splits the tokens into chunks of size max_length and calculates entropies for each chunk. + Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument. 
+ """ + with torch.no_grad(): + entropies = [] + max_length = getattr(entropy_model, "max_length", 8192) + batch_numel = max_length * patching_batch_size + splits = torch.split(tokens.flatten(), batch_numel) + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros( + pad_size, dtype=split.dtype, device=split.device, requires_grad=False + ) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + assert torch.all(split >= 0) and torch.all(split < 260) + pred, _ = entropy_model(split) + pred = pred.reshape(-1, pred.shape[-1])[ + : split.numel() - pad_size, : + ] # [batch_size * seq_len, vocab] + pred_entropies = entropy(pred) + entropies.append(pred_entropies) + + entropies = torch.cat(entropies, dim=0) + entropies = entropies.reshape(tokens.shape) + return entropies + + +def patch_start_mask_from_entropy_with_monotonicity(entropies, t): + """ + entropies: [bs, seq_len] torch tensor of entropies + t: threshold + returns [bs, seq_len] mask where True indicates the start of a patch + """ + bs, seq_len = entropies.shape + mask = torch.zeros_like(entropies, dtype=torch.bool) + mask[:, 0] = True + + # Calculate differences between consecutive elements along the sequence length + differences = entropies[:, 1:] - entropies[:, :-1] + + # Calculate conditions for all elements except the first one in each sequence + condition = differences > t + + # Update the mask based on the condition + mask[:, 1:] = condition + + return mask + + +def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0): + """ + entropies: [bs, seq_len] torch tensor of entropies + t: threshold + returns [bs, seq_len] mask where True indicates the start of a patch + """ + bs, seq_len = entropies.shape + mask = torch.zeros_like(entropies, dtype=torch.bool) + mask[:, 0] = True + + # Calculate differences between consecutive elements along the sequence length + differences = entropies[:, 1:] - entropies[:, :-1] + + # Calculate conditions for all elements except the first one in each sequence + condition = (differences > t_add) & (entropies[:, 1:] > t) & (~mask[:, :-1]) + + # Update the mask based on the condition + mask[:, 1:] = condition + + return mask + + +def patch_start_ids_from_patch_start_mask(patch_start_mask): + bs, trunc_seq_len = patch_start_mask.shape + max_patches = patch_start_mask.sum(dim=1).max() + if max_patches == 0: + patch_start_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + else: + patch_ids = ( + torch.arange(trunc_seq_len, device=patch_start_mask.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + extra_patch_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat( + (patch_start_mask, ~patch_start_mask), dim=1 + ) + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape( + bs, trunc_seq_len + )[:, :max_patches] + return patch_start_ids + + +def check_non_zero_after_zero(tensor): + zero_mask = tensor == 0 + shifted_mask = torch.cat( + [ + torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), + zero_mask[:, :-1], + ], + dim=1, + ) + non_zero_after_zero = (tensor != 0) & shifted_mask + return non_zero_after_zero.any() + + +def patch_lengths_from_start_ids(patch_start_ids, seq_len): + """ + Calculate patch lengths from 
start ids. + start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then + the rest are filled to the seq len. + seq_len: ex: 7 length of the sequence + + returns the patch lengths: + [1, 6] for the above example. + """ + last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) + patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) + patch_lengths = patch_end_ids - patch_start_ids + 1 + assert torch.all(patch_lengths >= 0), f"{patch_lengths}" + assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" + return patch_lengths + + +def find_space_patch_start_ids(tokens): + bs, seq_len = tokens.shape + tokens_no_offset = tokens - OFFSET + patch_end_mask = ( + (tokens_no_offset < ord("0")) + | ((ord("9") < tokens_no_offset) & (tokens_no_offset < ord("A"))) + | ((ord("Z") < tokens_no_offset) & (tokens_no_offset < ord("a"))) + | ((ord("z") < tokens_no_offset) & (tokens_no_offset < 0b1000_0000)) + | (0b1100_0000 <= tokens_no_offset) + ) + patch_end_mask[:, 1:] &= patch_end_mask[:, :-1].bitwise_not() + patch_end_mask |= tokens < OFFSET + + patch_start_mask = torch.cat( + [ + torch.tensor([1, 1], device=tokens.device, dtype=torch.bool) + .unsqueeze(0) + .repeat(bs, 1), + patch_end_mask[:, 1:], + ], + dim=1, + ) + max_patches = patch_start_mask.sum(dim=1).max() + + patch_ids = ( + torch.arange(seq_len + 1, device=tokens.device).unsqueeze(0).repeat(bs, 1) + ) + extra_patch_ids = torch.full( + (bs, seq_len + 1), seq_len + 1, dtype=torch.long, device=tokens.device + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) + + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, -1)[ + :, :max_patches + ] + return patch_start_ids + + +def to_device(entropy_model, device=None): + if device == "cuda": + rank = get_local_rank() + device = f"cuda:{rank}" + entropy_model = entropy_model.to(device) + return entropy_model, device + + +def model_pred_to_bpe_patching_pred(pred): + _, indices = torch.max(pred, dim=1) + return indices == BPE_ID + + +def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None): + assert tokens.device == torch.device( + "cpu" + ), f"{tokens.device} != cpu expects tokens to be on cpu" + with torch.no_grad(): + bpe_patcher_device, device = to_device( + bpe_patcher, device + ) # Get entropy model to right rank device. 
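# The BPE patcher is run in the same chunked fashion as calculate_entropies: inputs are split
# into chunks of max_length * patching_batch_size tokens, padded, and evaluated without
# gradients; model_pred_to_bpe_patching_pred then marks positions whose argmax equals BPE_ID
# as patch starts.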
+ bpe_patching_mask = [] + max_length = getattr(bpe_patcher, "max_length", 8192) + batch_numel = max_length * patching_batch_size + splits = torch.split(tokens.flatten(), batch_numel) + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros( + pad_size, dtype=split.dtype, device=split.device, requires_grad=False + ) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length).to(device) + assert torch.all(split >= 0) and torch.all(split < 260) + pred = bpe_patcher_device(split) + pred_cpu = pred[0].cpu() + pred_cpu = pred_cpu.reshape(-1, pred_cpu.shape[-1])[ + : split.numel() - pad_size, : + ] # [batch_size * seq_len, vocab] + bpe_patching_pred = model_pred_to_bpe_patching_pred(pred_cpu) + bpe_patching_mask.append(bpe_patching_pred) + bpe_patching_mask = torch.cat(bpe_patching_mask, dim=0) + bpe_patching_mask = bpe_patching_mask.reshape(tokens.shape) + return bpe_patching_mask + + +def find_bpe_patcher_patch_start_ids( + tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True +): + bs, seq_len = tokens.shape + + first_ids = ( + torch.tensor([0, 1], dtype=torch.long, device=tokens.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + preds_truncation_len = first_ids.shape[1] + token_input = tokens[:, 1:] if include_next_token else tokens[:, 1:-1] + if token_input.shape[1] >= 1: + patch_start_mask = apply_bpe_patcher( + token_input, bpe_patcher, patching_batch_size, device + ) + assert ( + patch_start_mask.shape[1] + == tokens.shape[1] + include_next_token - preds_truncation_len + ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}" + patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask) + patch_start_ids = torch.cat( + (first_ids, patch_start_ids + preds_truncation_len), dim=1 + ) + else: + patch_start_ids = first_ids + return patch_start_ids + + +def find_entropy_patch_start_ids( + entropies, + patch_size=None, + threshold=None, + threshold_add=None, + monotonicity=False, + include_next_token=True, +): + """ + Use entropies to find the start ids of each patch. + Use patch_size or threshold to figure out the total number of patches to allocate. + + When threshold is not None the number of patches is not constant between + different sequences, but patches can be identified incrementally rather than + decided globally using the entire sequence. + """ + bs, seq_len = entropies.shape[:2] + + first_ids = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + preds_truncation_len = first_ids.shape[ + 1 + ] # remove the first preds because they will be start of patches. 
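# Two regimes follow: with patch_size (threshold is None) the top seq_len // patch_size - 2
# entropies become patch starts in addition to the two forced ones, so the patch count per
# sequence is fixed; with a threshold (optionally the monotonicity variants above) every
# position whose entropy crosses it starts a patch, so the count varies but can be decided
# incrementally.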
+ entropies = entropies[:, 1:] + if threshold is None: + num_patches = seq_len // patch_size + patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices + patch_start_ids = patch_start_ids.sort(dim=1).values + else: + # Assumes that there is at least one token going over the threshold + if monotonicity: + patch_start_mask = patch_start_mask_from_entropy_with_monotonicity( + entropies, threshold + ) + elif threshold_add is not None and threshold is not None: + patch_start_mask = patch_start_mask_global_and_monotonicity( + entropies, threshold, threshold_add + ) + else: + patch_start_mask = entropies > threshold + if not include_next_token: + patch_start_mask = patch_start_mask[:, :-1] + # patch_start_mask[1:] |= tokens[:-1] < OFFSET + patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask) + + patch_start_ids = torch.cat( + (first_ids, patch_start_ids + preds_truncation_len), dim=1 + ) + return patch_start_ids + + +def rightpad(seq, pad_id, max_len): + return seq + [pad_id] * (max_len - len(seq)) + + +def find_bpe_delim_patch_start_ids(tokens, delim): + ids = (tokens[:, :-1] == delim).nonzero(as_tuple=False) + out = [[0, 1] for _ in range(tokens.shape[0])] + for x, y in ids: + # start is at delim + 1, delim should be the last element in the patch. + out[x.item()].append(y.item() + 1) + max_len = max([len(elt) for elt in out]) + out = [rightpad(elt, tokens.shape[1], max_len) for elt in out] + patch_start_ids = torch.tensor(out, dtype=tokens.dtype, device=tokens.device) + return patch_start_ids + + +def find_lookup_table_start_mask( + tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True +): + window_size = lookup_table.ndim + # Unfold the tensor to get sliding windows + unfolded = tokens.unfold(1, window_size, 1) + # Gather indices for each dimension + indices = [unfolded[..., i] for i in range(window_size)] + # Access the lookup table using the gathered indices + result = lookup_table[indices] + return result + + +def find_lookup_table_patch_start_ids( + tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True +): + bs, seq_len = tokens.shape + + first_ids = ( + torch.tensor([0, 1], dtype=torch.long, device=tokens.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + preds_truncation_len = first_ids.shape[1] + window_size = lookup_table.ndim + assert window_size == 2, f"{window_size} != 2" + # output dimensions: token_input shape - window_size + 1 --> we want first ids + this = tokens shape + 1 if next token otherwise just token shape + token_input = ( + tokens if include_next_token else tokens[:, : -preds_truncation_len + 1] + ) + if token_input.shape[1] >= window_size: + patch_start_mask = find_lookup_table_start_mask( + token_input, lookup_table, include_next_token + ) + assert ( + patch_start_mask.shape[1] + == tokens.shape[1] + include_next_token - preds_truncation_len + ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}" + patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask) + patch_start_ids = torch.cat( + (first_ids, patch_start_ids + preds_truncation_len), dim=1 + ) + else: + patch_start_ids = first_ids + return patch_start_ids + + +def split_large_numbers(lst, m): + new_lst = [] + for i in lst: + if i > m: + while i > m: + new_lst.append(m) + i -= m + new_lst.append(i) + else: + new_lst.append(i) + assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" + return new_lst + + +class Patcher: + def __init__(self, patcher_args: PatcherArgs): + 
self.patcher_args = patcher_args + self.patching_mode = patcher_args.patching_mode + self.realtime_patching = patcher_args.realtime_patching + if self.realtime_patching: + assert ( + patcher_args.entropy_model_checkpoint_dir is not None + ), "Cannot require realtime patching without an entropy model checkpoint" + entropy_model = load_entropy_model( + patcher_args.entropy_model_checkpoint_dir + ) + entropy_model, _ = to_device(entropy_model, patcher_args.patching_device) + self.entropy_model = entropy_model + else: + self.entropy_model = None + self.threshold = patcher_args.threshold + self.threshold_add = patcher_args.threshold_add + self.max_patch_length = patcher_args.max_patch_length + self.patch_size = patcher_args.patch_size + self.patching_batch_size = patcher_args.patching_batch_size + self.data_loader_patching = patcher_args.data_loader_patching + self.device = patcher_args.device + self.monotonicity = patcher_args.monotonicity + self.log_time = patcher_args.log_time + if self.log_time: + self.log = defaultdict(float) + + def patch( + self, + tokens: torch.Tensor, + include_next_token: bool = False, + preds: torch.Tensor | None = None, + entropies: torch.Tensor | None = None, + threshold: float = None, + ) -> torch.Tensor: + """ + tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched + Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.) + -> output tensor: [batch_size, max_num_patches] + each tensor is processed independently and gets right padded with zeros. + + Patching with the following modes: + 1. patching_mode = None: static patch size + 2. patching_mode = "entropy": + calculate entropy of each token, allocate patches so that the total + number of patches is the same as static patching but choose to begin + patches on tokens where the model is most uncertain (highest entropy). + + When threshold is provided, it uses the threshold to decide when to + start a new patch. + 3. patching_mode = "space": + use space like tokens to define the patches. + 4. patching_mode = "bpe": + use bpe delim tokens to define the patches. + + To correctly patch the last token, it may be necessary to include the next token in the patch + lengths calculations. This is controlled by the include_next_token argument. 
+ """ + bs, seq_len = tokens.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + scores = None + # STATIC + if self.patching_mode is None: + patch_lengths = torch.zeros( + (bs, math.ceil(seq_len_next_tok / self.patch_size)), + dtype=tokens.dtype, + device=tokens.device, + ).fill_(self.patch_size) + if seq_len_next_tok % self.patch_size != 0: + patch_lengths[:, -1] = seq_len_next_tok % self.patch_size + # ENTROPY + elif self.patching_mode == PatchingModeEnum.entropy: + if self.log_time: + s = time.time() + if entropies is not None: + scores = torch.tensor(entropies, dtype=torch.float32) + elif preds is not None: + scores = entropy(preds) + else: + start_entropies = time.time() + scores = calculate_entropies( + tokens, + self.entropy_model, + self.patching_batch_size, + self.device, + ) + if self.log_time: + self.log["calculate_entropies"] += time.time() - s + s = time.time() + patch_start_ids = find_entropy_patch_start_ids( + scores, + self.patch_size, + include_next_token=include_next_token, + threshold=threshold if threshold is not None else self.threshold, + threshold_add=self.threshold_add, + monotonicity=self.monotonicity, + ) + if self.log_time: + self.log["find_entropy_patch_start_ids"] += time.time() - s + s = time.time() + patch_lengths = patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + if self.log_time: + self.log["patch_lengths_from_start_ids"] += time.time() - s + s = time.time() + # BPE + elif self.patching_mode == PatchingModeEnum.bpe: + patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID) + patch_lengths = patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + elif self.patching_mode == PatchingModeEnum.bpe_patcher: + patch_start_ids = find_bpe_patcher_patch_start_ids( + tokens, + self.entropy_model, + self.patching_batch_size, + self.device, + include_next_token, + ) + patch_lengths = patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + # SPACE + elif self.patching_mode == PatchingModeEnum.space: + patch_start_ids = find_space_patch_start_ids(tokens) + patch_lengths = patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + else: + raise NotImplementedError(f"self.patching_mode {self.patching_mode}") + + # Apply any processing to patch lengths + if self.max_patch_length is not None: + # TODO: avoid going back to a list here. 
+ patch_lengths = [ + split_large_numbers(pl, self.max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=tokens.dtype, device=tokens.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] + assert ( + torch.sum(patch_lengths) + == tokens.numel() + include_next_token * tokens.shape[0] + ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}" + if self.log_time: + self.log["postprocessing_patch_lengths"] += time.time() - s + self.log["tokens"] += patch_lengths.sum().item() + return patch_lengths, scores diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py new file mode 100644 index 0000000..fadce45 --- /dev/null +++ b/bytelatent/distributed.py @@ -0,0 +1,478 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import atexit +import contextlib +import logging +import multiprocessing as mp +import os +import random +import shutil +import signal +import socket +import subprocess +import sys +import tempfile +from dataclasses import asdict, dataclass +from functools import lru_cache, partial, reduce +from itertools import chain +from typing import List, Optional, Tuple, Union + +import torch + +# for no recompute ops +import xformers.ops +from pydantic import BaseModel, ConfigDict +from torch import distributed as dist +from torch.distributed import ReduceOp +from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard +from torch.distributed._tensor import DTensor +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, +) +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import ( + CheckpointPolicy, + create_selective_checkpoint_contexts, +) + +from bytelatent.float8 import convert_linears_to_fp8 + +logger = logging.getLogger() + +# for selective AC +default_no_recompute_ops = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.c10d_functional.reduce_scatter_tensor.default, + torch.ops.xformers_flash.flash_fwd.default, + torch.ops.xformers.efficient_attention_forward_cutlass.default, +} + + +class DistributedArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dp_shard: int = ( + 1 # In how many shard to split the model weight. Typically number gpu in a node. + ) + dp_replicate: int = ( + 1 # How many times to replicate the model weight. Typically number of nodes. + ) + tp_size: int = 1 + selective_activation_checkpointing: bool = False + compile: bool = False + fsdp_type: str = "no_shard" + model_dtype: str = "bf16" + float8_recipe: str | None = None + float8_filter: str = r"layers\.[0-9]+\." 
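    # Example layout (hypothetical numbers): a 2-node job with 8 GPUs per node could use
    # dp_replicate=2, dp_shard=8, tp_size=1; get_device_mesh below requires
    # dp_replicate * dp_shard * tp_size to equal the world size (16 here).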
+ + matmul_allow_tf32: bool = False + allow_bf16_reduced_precision_reduction: bool = True + detect_anomaly: bool = False + + compile_cache_size_limit: int = 8 + + spawn_method: str = "forkserver" + + +class EnvironmentArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + # Use GNU openMP (GOMP) instead of Intel OpenMP [Intel Math Kernel Library (MKL)] + MKL_SERVICE_FORCE_INTEL: str = "GNU" + OMP_NUM_THREADS: str = "1" + MKL_NUM_THREADS: str = "1" + # faster intra-node collectives, seems to be a cluster specific flag + ENABLE_INTRA_NODE_COMM: str = "1" + # avoids OOMs with long context + TORCH_NCCL_AVOID_RECORD_STREAMS: str = "1" + # increasing NCCL timeout time before having some NCCL error 22 should give a 16s timeout + NCCL_IB_TIMEOUT: str = "22" + NCCL_DEBUG: str = "INFO" + TORCH_NCCL_ASYNC_ERROR_HANDLING: str = "1" + + +def get_device_mesh(distributed_args: DistributedArgs): + tp_size = distributed_args.tp_size + dp_replicate = distributed_args.dp_replicate + dp_shard = distributed_args.dp_shard + + assert ( + dp_replicate * dp_shard * tp_size == get_world_size() + ), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})" + + dims = [] + names = [] + if dp_replicate >= 1: + dims.append(dp_replicate) + names.append("dp_replicate") + if dp_shard > 1 or distributed_args.fsdp_type == "no_shard": + dims.append(dp_shard) + names.append("dp_shard") + if tp_size > 1: + dims.append(tp_size) + names.append("tp") + dims = tuple(dims) + names = tuple(names) + + return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names) + + +def dist_max(x: Union[int, float], mesh: DeviceMesh = None): + tensor = torch.tensor(x).cuda() + dist.all_reduce(tensor, op=ReduceOp.MAX, group=mesh.get_group() if mesh else None) + return tensor + + +def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): + tensor = torch.tensor(x).cuda() + dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) + return tensor + + +def dist_mean_dict(x): + r = dict() + for k in x: + r[k] = dist_mean(x[k]) + r[k] = r[k].item() if (r[k].dim() == 0) else r[k].tolist() + return r + + +@lru_cache() +def get_is_torch_run() -> bool: + return os.environ.get("LOCAL_RANK") is not None + + +@lru_cache() +def get_is_slurm_job() -> bool: + return "SLURM_JOB_ID" in os.environ and not get_is_torch_run() + + +@lru_cache() +def get_global_rank() -> int: + if get_is_torch_run(): + return int(os.environ["RANK"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_PROCID"]) + else: + return 0 + + +@lru_cache() +def get_local_rank() -> int: + if get_is_torch_run(): + return int(os.environ["LOCAL_RANK"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_LOCALID"]) + else: + return 0 + + +@lru_cache() +def get_world_size() -> int: + if get_is_torch_run(): + return int(os.environ["WORLD_SIZE"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_NTASKS"]) + else: + return 1 + + +@lru_cache() +def get_is_master() -> bool: + return get_global_rank() == 0 + + +@lru_cache() +def get_master_port(job_id: int) -> int: + if get_is_torch_run(): + return int(os.environ["MASTER_PORT"]) + else: + MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000) + rng = random.Random(job_id) + return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) + + +@lru_cache() +def get_master_addr() -> str: + if get_is_torch_run(): + return os.environ["MASTER_ADDR"] + elif get_is_slurm_job(): + hostnames = subprocess.check_output( + ["scontrol", "show", "hostnames", 
os.environ["SLURM_JOB_NODELIST"]] + ) + return hostnames.split()[0].decode("utf-8") + else: + return "127.0.0.1" + + +def setup_env(env_args: EnvironmentArgs): + env_vars = env_args.model_dump() + + # When using Triton, it attempts to locate prebuilt kernels in a cache + # located at ~/.triton/cache, but when that's backed by NFS this can fail + # with a "OSError: [Errno 116] Stale file handle" error. If we were to set + # it to a local directory it would belong to the first user who created it + # and it would fail for the job of any other successive user assigned to + # that machine. To avoid all this mess we use a temporary per-process cache. + triton_cache_dir = tempfile.mkdtemp() + atexit.register(shutil.rmtree, triton_cache_dir, ignore_errors=True) + env_vars["TRITON_CACHE_DIR"] = triton_cache_dir + + # We change the tmp dir to /scratch in case it's slurm job + # This avoids filling up the host's usually limited tmpfs + # A full tmpfs leads to very slow creation of processes and weird bugs + if get_is_slurm_job(): + new_tmp = f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}" + if os.path.exists(new_tmp): + env_vars["TMP_DIR"] = new_tmp + + for name, value in env_vars.items(): + if os.environ.get(name) != str(value): + os.environ[name] = str(value) + logger.warning(f"WARNING: Setting {name} to {value}") + + +def setup_torch_distributed(dist_args): + """ + Handle single and multi-GPU / multi-node / SLURM jobs. + Initialize the following variables: + - global_rank + - world_size + """ + mp.set_start_method(dist_args.spawn_method) + with mp.Manager(): + pass + + local_rank = get_local_rank() + + os.environ["RANK"] = str(get_global_rank()) + os.environ["WORLD_SIZE"] = str(get_world_size()) + os.environ["MASTER_ADDR"] = get_master_addr() + os.environ["MASTER_PORT"] = str( + get_master_port(job_id=int(os.environ.get("SLURM_JOB_ID", -1))) + ) + + if get_is_torch_run(): + logger.info(f"Run launched with torchrun, local rank: {local_rank}") + elif get_is_slurm_job(): + logger.info(f"Run launched with slurm, local rank: {local_rank}") + else: + logger.info("Single GPU job") + + logger.info(f"ENV: {os.environ}") + + # set GPU device + assert 0 <= local_rank < 8 + if dist_args.matmul_allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + logger.warning( + f"WARNING: Setting torch.backends.matmul.allow_tf32 to True. This is faster but less accurate." 
+ ) + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = ( + dist_args.allow_bf16_reduced_precision_reduction + ) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group(init_method="env://", backend="nccl") + torch.autograd.set_detect_anomaly(dist_args.detect_anomaly) + + +def get_module(module, access_string): + names = access_string.split(sep=".") + return reduce(getattr, names, module) + + +def set_module(module, access_string, value): + names = access_string.split(sep=".") + parent = reduce(getattr, names[:-1], module) + setattr(parent, names[-1], value) + + +def default_fsdp_grouping_plan(n_layers: int) -> List[Tuple[str, bool]]: + return [(f"layers.{i}", i < n_layers - 1) for i in range(n_layers)] + + +def get_default_policy(no_recompute_ops=None): + no_recompute_ops = no_recompute_ops or default_no_recompute_ops + + def default_policy(ctx, func, *args, **kwargs): + return ( + CheckpointPolicy.MUST_SAVE + if func in no_recompute_ops + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return default_policy + + +@torch.no_grad() +def check_model_value_range( + model: torch.nn.Module, range: float = 1e3, std: float = 1e3 +): + for name, param in chain(model.named_parameters(), model.named_buffers()): + if isinstance(param, DTensor): + param = param.to_local() + + if param.numel() == 0: + logger.warning( + f"Model parameter {name} is empty, probably because of FSDP sharding" + ) + continue + + if torch.isnan(param).any() or torch.isinf(param).any(): + logger.warning(f"Model parameter {name} contains NaN or Inf") + + param_range = param.max() - param.min() + param_std = param.std() + if param_range > range: + logger.warning( + f"Model parameter {name} has a suspiciously large range ({param_range}): please check initialization and init_weights is defined and called" + ) + if param_std > std: + logger.warning( + f"Model parameter {name} has a suspiciously large standard deviation ({param_std}): please check initialization and init_weights is defined and called" + ) + if (param == 0).all(): + logger.warning( + f"Model parameter {name} is all zeros: it might be because of a missing initialization" + ) + + +def init_signal_handler(callable): + """ + Handle signals sent by SLURM for time limit / pre-emption. 
+ """ + signal.signal(signal.SIGUSR2, callable) + logger.warning("Signal handler installed.") + + +def requeue_slurm_job(): + prod_id = int(os.environ["SLURM_PROCID"]) + logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) + if prod_id == 0 and os.environ.get("LAUNCH_WITH", "") != "DORA": + logger.warning("Requeuing job " + os.environ["SLURM_JOB_ID"]) + os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"]) + else: + logger.warning("Not the master process, no need to requeue.") + sys.exit(0) + + +@contextlib.contextmanager +def clean_env(): + distrib_names = ( + "MASTER_ADDR", + "MASTER_PORT", + "RANK", + "WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_WORLD_SIZE", + "TORCHELASTIC_RUN_ID", + "DORA_FORCE_DISTRIB", + ) + cluster_env = { + x: os.environ.pop(x) + for x in os.environ + if x.startswith( + ("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_", "WANDB_") + ) + or x in distrib_names + } + try: + yield + finally: + os.environ.update(cluster_env) + + +def parallelize_model( + model, + device_mesh, + model_args, + distributed_args: DistributedArgs, + fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, + tp_parallelize=None, + no_recompute_ops=None, +): + if distributed_args.tp_size > 1: + assert ( + distributed_args.fsdp_type == "full_shard" + ), "Only full shard is supported for TP parallelism" + assert tp_parallelize is not None, "TP plan is required for TP parallelism" + assert ( + distributed_args.compile == False + ), "Compile is not supported for TP parallelism" + + tp_parallelize(model, device_mesh["tp"], model_args, distributed_args) + + if distributed_args.float8_recipe is not None: + if distributed_args.tp_size > 1: + raise RuntimeError("float8 is incompatible with tensor-parallelism for now") + model = convert_linears_to_fp8( + model, distributed_args.float8_recipe, distributed_args.float8_filter + ) + + param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ + distributed_args.model_dtype + ] + if ( + distributed_args.fsdp_type == "full_shard" + or distributed_args.fsdp_type == "no_shard" + ): + if distributed_args.fsdp_type == "no_shard": + assert ( + distributed_args.dp_shard == 1 + ), "dp_shard must be 1 for no_shard fsdp_type" + assert ( + device_mesh["dp_shard"].size() == 1 + ), "dp_shard must be 1 for no_shard fsdp_type" + + fsdp_config = dict( + mp_policy=( + MixedPrecisionPolicy( + param_dtype=param_dtype, + reduce_dtype=torch.float32, + ) + ), + mesh=( + device_mesh["dp_replicate", "dp_shard"] + if distributed_args.dp_shard > 1 + or distributed_args.fsdp_type == "no_shard" + else device_mesh["dp_replicate"] + ), + ) + + if fsdp_grouping_plan is None: + # Assume that the model has list of layers and group around it + fsdp_grouping_plan = default_fsdp_grouping_plan(len(model.layers)) + + for path, reshard_after_forward in fsdp_grouping_plan: + module = get_module(model, path) + set_module( + model, + path, + fully_shard( + module, **fsdp_config, reshard_after_forward=reshard_after_forward + ), + ) + + model = fully_shard(model, **fsdp_config, reshard_after_forward=True) + else: + raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}") + + if distributed_args.selective_activation_checkpointing: + model = checkpoint_wrapper( + model, + context_fn=partial( + create_selective_checkpoint_contexts, + get_default_policy(no_recompute_ops), + ), + ) + + if distributed_args.compile: + torch._dynamo.config.cache_size_limit = ( + distributed_args.compile_cache_size_limit + ) + model = torch.compile(model) + + 
return model diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py new file mode 100644 index 0000000..1bd1766 --- /dev/null +++ b/bytelatent/entropy_model.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import json +import os +import re + +import torch + +from bytelatent.transformer import LMTransformer, LMTransformerArgs + + +def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"): + with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: + reloaded = json.loads(fr.read()) + + torch.set_default_dtype(torch.bfloat16) + model_params = reloaded["model"] + entropy_model = LMTransformer( + LMTransformerArgs( + dim=model_params["dim"], + n_layers=model_params["n_layers"], + n_heads=model_params["n_heads"], + max_seqlen=model_params["max_length"], + ffn_dim_multiplier=model_params["ffn_dim_multiplier"], + vocab_size=model_params["vocab_size"], + ) + ) + + entropy_model.load_state_dict( + torch.load(state_dict_path, map_location=device), strict=False + ) + entropy_model.to(device) + entropy_model = entropy_model.eval() + # no grads for the model: + for param in entropy_model.parameters(): + param.requires_grad = False + return entropy_model diff --git a/bytelatent/float8.py b/bytelatent/float8.py new file mode 100644 index 0000000..6476862 --- /dev/null +++ b/bytelatent/float8.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import re +import warnings +from typing import Callable + +import torch + +# avoid division by zero when calculating scale +EPS = 1e-12 + + +def scale(t, amax_t, dtype_t): + min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max + scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v + t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t) + return t_fp8, scale_t + + +def matmul( + first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias +): + first_fp8, scale_first = scale(first, amax_first, dtype_first) + second_t_fp8, scale_second_t = scale(second_t, amax_second_t, dtype_second_t) + output = torch._scaled_mm( + first_fp8, + second_t_fp8.t(), + scale_a=scale_first, + scale_b=scale_second_t.t(), + bias=bias, + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + return output + + +@torch._dynamo.allow_in_graph +class Fp8LinearFn(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b_t, bias): + amax_a = a.abs().amax(dim=-1, keepdim=True) + amax_b_t = b_t.abs().amax(dim=-1, keepdim=True) + out = matmul( + a, amax_a, torch.float8_e4m3fn, b_t, amax_b_t, torch.float8_e4m3fn, bias + ) + + ctx.a_requires_grad = a.requires_grad + ctx.b_requires_grad = b_t.requires_grad + ctx.bias_requires_grad = bias.requires_grad if bias is not None else False + + ctx.save_for_backward(a, b_t, amax_b_t.max()) + + return out + + @staticmethod + def backward(ctx, grad_out): + a, b_t, amax_b = ctx.saved_tensors + + if ctx.a_requires_grad: + b = b_t.t().contiguous() + amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True) + amax_b = amax_b.repeat(b.shape[0], 1) + grad_a = matmul( + grad_out, + amax_grad_out, + torch.float8_e4m3fn, + b, + amax_b, + torch.float8_e4m3fn, + None, + ) + else: + grad_a = None + if ctx.b_requires_grad: + grad_b = grad_out.t() @ a + else: + grad_b = None + if ctx.bias_requires_grad: + grad_bias = grad_out.sum(dim=0) + else: + grad_bias = None + + return grad_a, grad_b, grad_bias + + +class Fp8Linear(torch.nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = 
Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias) + out = out.unflatten(0, input.shape[:-1]) + return out + + +def named_replace( + fn: Callable[[torch.nn.Module, str], torch.nn.Module], + module: torch.nn.Module, + name="", +) -> torch.nn.Module: + for child_name, child_module in list(module.named_children()): + full_name = f"{name}.{child_name}" if name else child_name + new_child_module = named_replace(fn, child_module, full_name) + setattr(module, child_name, new_child_module) + module = fn(module, name) + return module + + +def convert_linears_to_fp8( + root_module: torch.nn.Module, recipe: str, filter: str +) -> torch.nn.Module: + if recipe not in ["rowwise"]: + raise RuntimeError(f"Unknown float8 recipe {recipe!r}") + + if recipe == "rowwise" and torch.__version__ < "2.5": + # We need https://github.com/pytorch/pytorch/pull/134781. + warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0") + + # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based + # reduction kernel and a "persistent" reduction kernel. Since fp8 has some + # multi-pass steps (e.g., first get amax, then scale), persistent kernels + # should perform better. + torch._inductor.config.triton.multi_kernel = 1 + + filter_re = re.compile(filter) + + def replace(module: torch.nn.Module, name: str) -> torch.nn.Module: + if not isinstance(module, torch.nn.Linear) or not filter_re.search(name): + return module + if type(module) == torch.nn.Linear: + if recipe == "rowwise": + new_module = Fp8Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + dtype=module.weight.dtype, + device=module.weight.device, + ) + new_module.weight = module.weight + new_module.bias = module.bias + else: + assert False, recipe + else: + assert False, str(type(module)) + return new_module + + out = named_replace(replace, root_module) + + # Force re-compile everything + torch._dynamo.reset_code_caches() + from torch._inductor.cudagraph_trees import reset_cudagraph_trees + + reset_cudagraph_trees() + + return out diff --git a/bytelatent/logger.py b/bytelatent/logger.py new file mode 100644 index 0000000..6723a84 --- /dev/null +++ b/bytelatent/logger.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import logging +import math +import sys +import time +from datetime import timedelta + +from bytelatent.distributed import get_global_rank, get_is_slurm_job + + +class LogFormatter(logging.Formatter): + """ + Custom logger for distributed jobs, displaying rank + and preserving indent from the custom prefix format. 
+ """ + + def __init__(self): + self.start_time = time.time() + self.rank = get_global_rank() + self.show_rank = not get_is_slurm_job() # srun has --label + + def formatTime(self, record): + subsecond, seconds = math.modf(record.created) + curr_date = ( + time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds)) + + f".{int(subsecond * 1_000_000):06d}" + ) + delta = timedelta(seconds=round(record.created - self.start_time)) + return f"{curr_date} - {delta}" + + def formatPrefix(self, record): + fmt_time = self.formatTime(record) + if self.show_rank: + return f"{self.rank}: {record.levelname:<7} {fmt_time} - " + else: + return f"{record.levelname:<7} {fmt_time} - " + + def formatMessage(self, record, indent: str): + content = record.getMessage() + content = content.replace("\n", "\n" + indent) + # Exception handling as in the default formatter, albeit with indenting + # according to our custom prefix + if record.exc_info: + # Cache the traceback text to avoid converting it multiple times + # (it's constant anyway) + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + if record.exc_text: + if content[-1:] != "\n": + content = content + "\n" + indent + content = content + indent.join( + [l + "\n" for l in record.exc_text.splitlines()] + ) + if content[-1:] == "\n": + content = content[:-1] + if record.stack_info: + if content[-1:] != "\n": + content = content + "\n" + indent + stack_text = self.formatStack(record.stack_info) + content = content + indent.join([l + "\n" for l in stack_text.splitlines()]) + if content[-1:] == "\n": + content = content[:-1] + + return content + + def format(self, record): + prefix = self.formatPrefix(record) + indent = " " * len(prefix) + content = self.formatMessage(record, indent) + return prefix + content + + +def set_root_log_level(log_level: str): + logger = logging.getLogger() + level: int | str = log_level.upper() + try: + level = int(log_level) + except ValueError: + pass + try: + logger.setLevel(level) # type: ignore + except Exception: + logger.warning( + f"Failed to set logging level to {log_level}, using default 'NOTSET'" + ) + logger.setLevel(logging.NOTSET) + + +def init_logger( + log_file: str | None = None, + *, + name: str | None = None, + level: str = "NOTSET", +): + """ + Setup logging. + + Args: + log_file: A file name to save file logs to. + name: The name of the logger to configure, by default the root logger. + level: The logging level to use. + """ + set_root_log_level(level) + logger = logging.getLogger(name) + + # stdout: everything + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.NOTSET) + stdout_handler.setFormatter(LogFormatter()) + + # stderr: warnings / errors and above + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setLevel(logging.WARNING) + stderr_handler.setFormatter(LogFormatter()) + + # set stream handlers + logger.handlers.clear() + logger.handlers.append(stdout_handler) + logger.handlers.append(stderr_handler) + + if log_file is not None and get_global_rank() == 0: + # build file handler + file_handler = logging.FileHandler(log_file, "a") + file_handler.setLevel(logging.NOTSET) + file_handler.setFormatter(LogFormatter()) + # update logger + logger = logging.getLogger() + logger.addHandler(file_handler) diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py new file mode 100644 index 0000000..77dc4d7 --- /dev/null +++ b/bytelatent/metrics.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import json +import logging +from collections import namedtuple +from dataclasses import asdict +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Union + +import torch +import torch.nn as nn +import wandb +from pydantic import BaseModel, ConfigDict + +from bytelatent.distributed import get_is_master + +logger = logging.getLogger() + + +class WandbArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + job_type: str | None = None + dir: str | None = None + project: str | None = None + entity: str | None = None + tags: list | None = None + group: str | None = None + name: str | None = None + notes: str | None = None + config_exclude_keys: list[str] | None = None + config_include_keys: list[str] | None = None + anonymous: str | None = None + mode: str | None = None + allow_val_change: bool | None = None + resume: Union[bool, str] | None = None + force: bool | None = None + tensorboard: bool | None = None + sync_tensorboard: bool | None = None + monitor_gym: bool | None = None + save_code: bool | None = None + id: str | None = None + fork_from: str | None = None + resume_from: str | None = None + + +class LoggingArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + freq: int = 10 # Log every freq optimizer steps + acc_freq: int | None = None # Log every acc_freq gradient accumulation steps + + wandb: WandbArgs | None = None + + +class MetricLogger: + def __init__(self, outdir: Path, args: Any | None = None): + self.outdir = outdir + self.jsonl_writer = None + self.args = args + + def open(self): + if self.jsonl_writer is None: + self.jsonl_writer = open(self.outdir, "a") + if ( + self.args is not None + and self.args.logging.wandb is not None + and get_is_master() + ): + run = wandb.init( + config=asdict(self.args), + **asdict(self.args.logging.wandb), + ) + + def log(self, metrics: dict[str, Any]): + if ( + self.args is not None + and self.args.logging.wandb is not None + and (wandb.run is not None) + ): + wandb.log(metrics, step=metrics["global_step"]) + + metrics.update({"created_at": datetime.now(timezone.utc).isoformat()}) + print(json.dumps(metrics), file=self.jsonl_writer, flush=True) + + def close(self): + if self.jsonl_writer is not None: + self.jsonl_writer.close() + self.jsonl_writer = None + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + self.close() + + +GPUMemStats = namedtuple( + "GPUMemStats", + [ + "max_active_gib", + "max_active_pct", + "max_reserved_gib", + "max_reserved_pct", + "num_alloc_retries", + "num_ooms", + "power_draw", + ], +) + + +class GPUMemoryMonitor: + """ + Class to monitor GPU memory usage + """ + + def __init__(self, device: str = "cuda:0"): + self.device = torch.device(device) # device object + self.device_name = torch.cuda.get_device_name(self.device) + self.device_index = torch.cuda.current_device() + self.device_capacity = torch.cuda.get_device_properties( + self.device + ).total_memory + self.device_capacity_gib = self._to_gib(self.device_capacity) + + # reset stats, clear cache + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + def _to_gib(self, memory_in_bytes): + # NOTE: GiB (gibibyte) is 1024, vs GB is 1000 + _gib_in_bytes = 1024 * 1024 * 1024 + memory_in_gib = memory_in_bytes / _gib_in_bytes + return memory_in_gib + + def _to_pct(self, memory): + return 100 * 
memory / self.device_capacity + + def get_peak_stats(self): + cuda_info = torch.cuda.memory_stats(self.device) + + max_active = cuda_info["active_bytes.all.peak"] + max_active_gib = self._to_gib(max_active) + max_active_pct = self._to_pct(max_active) + + max_reserved = cuda_info["reserved_bytes.all.peak"] + max_reserved_gib = self._to_gib(max_reserved) + max_reserved_pct = self._to_pct(max_reserved) + + num_retries = cuda_info["num_alloc_retries"] + num_ooms = cuda_info["num_ooms"] + power_draw = torch.cuda.power_draw() + + if num_retries > 0: + logger.warning(f"{num_retries} CUDA memory allocation retries.") + if num_ooms > 0: + logger.warning(f"{num_ooms} CUDA OOM errors thrown.") + + return GPUMemStats( + max_active_gib, + max_active_pct, + max_reserved_gib, + max_reserved_pct, + num_retries, + num_ooms, + power_draw, + ) + + def reset_peak_stats(self): + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + + def __str__(self): + mem_stats = self.get_peak_stats() + display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib} GiB capacity, " + display_str += ( + f"{mem_stats.max_reserved_gib} GiB peak, {mem_stats.max_reserved_pct}% peak" + ) + return f"{display_str}" + + +def upload_train_to_wandb( + ckpt_dir, project="lingua", entity="codegen-team", train=True, eval=True +): + import json + from pathlib import Path + + import wandb + from omegaconf import OmegaConf + + cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml") + cfg = OmegaConf.to_container(cfg) + + if train: + wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity) + + with open(Path(ckpt_dir) / "metrics.jsonl") as f: + for l in f: + m = json.loads(l) + wandb.log(m, step=m["global_step"]) + + wandb.finish() + + if eval: + wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity) + + with open(Path(ckpt_dir) / "metrics.eval.jsonl") as f: + for l in f: + m = json.loads(l) + wandb.log( + { + f"evals/{name.replace('/','.')}": value + for name, value in m.items() + if "/" in name + }, + step=m["global_step"], + ) + + wandb.finish() + + +def get_num_params(model: nn.Module) -> int: + """ + Get the total model params + Args : only_trainable: whether to only count trainable params + """ + numel = {n: p.numel() for n, p in model.named_parameters()} + return sum(numel.values()) diff --git a/bytelatent/model/__init__.py b/bytelatent/model/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/model/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py new file mode 100644 index 0000000..9332d19 --- /dev/null +++ b/bytelatent/model/blt.py @@ -0,0 +1,1064 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
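Before the BLT model code that follows, here is a hedged sketch of how the metrics helpers above (`MetricLogger` and `GPUMemoryMonitor`) might be wired into a training loop. The file name and loop body are placeholders, and the import path is taken from the diff header above.

```python
# Illustrative only; `...` stands in for the actual forward/backward/optimizer step.
from pathlib import Path

from bytelatent.metrics import GPUMemoryMonitor, MetricLogger

gpu_monitor = GPUMemoryMonitor("cuda:0")

# Note: despite the name, `outdir` is treated as the metrics *file*; MetricLogger
# appends one JSON object per line to it.
with MetricLogger(Path("metrics.jsonl")) as metric_logger:
    for step in range(1, 1001):
        ...  # training step
        if step % 10 == 0:
            mem = gpu_monitor.get_peak_stats()
            metric_logger.log(
                {
                    "global_step": step,
                    "max_reserved_gib": mem.max_reserved_gib,
                    "max_reserved_pct": mem.max_reserved_pct,
                    "num_ooms": mem.num_ooms,
                }
            )
            gpu_monitor.reset_peak_stats()
```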
+ +from enum import Enum, auto +from typing import Any, Optional + +import torch +from pydantic import ConfigDict, model_validator +from torch import nn +from torch.nn.attention.flex_attention import create_block_mask +from typing_extensions import Self + +from bytelatent.base_transformer import ( + BaseTransformerArgs, + InitStdFactor, + TransformerBlock, +) +from bytelatent.data.patcher import Patcher, PatcherArgs +from bytelatent.model.local_models import LocalDecoder, LocalEncoder +from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.utils import downsample +from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID + + +def attention_flops_per_token(n_layers, seq_len, dim, causal): + # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30 + return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1)) + + +def get_num_flop_per_token( + num_non_embed_params: int, n_layers: int, dim: int, seq_len: int +) -> int: + return 6 * num_non_embed_params + attention_flops_per_token( + n_layers, seq_len, dim, True + ) + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def setattrs(_self, **kwargs): + for k, v in kwargs.items(): + setattr(_self, k, v) + + +def get_encoder_dim_token_emb(args): + if args.dim_token is not None: + dim_token_emb = args.dim_token + elif args.use_local_encoder_transformer: + dim_token_emb = args.dim_local_encoder + else: + dim_token_emb = args.dim_global // args.patch_size + return dim_token_emb + + +def get_encoder_dim_patch_emb(args): + dim_patch_emb = None + if args.cross_attn_encoder: + if args.cross_attn_init_by_pooling: + dim_patch_emb = args.dim_local_encoder + else: + dim_patch_emb = args.dim_global + return dim_patch_emb + + +def get_global_dim_patch_emb(args): + dim_token_emb = get_encoder_dim_token_emb(args) + if args.cross_attn_encoder: + dim_patch_emb = dim_token_emb * args.cross_attn_k + elif ( + args.downsampling_by_pooling is None + or not args.downsampling_by_pooling + or len(args.downsampling_by_pooling) == 0 + ): + dim_patch_emb = dim_token_emb * args.patch_size + else: + dim_patch_emb = dim_token_emb * sum( + [ + pooling in args.downsampling_by_pooling + for pooling in ["avg", "min", "max"] + ] + ) + return dim_patch_emb + + +def get_decoder_dim_token_emb(args): + if args.share_encoder_decoder_emb: + dim_token_emb = get_encoder_dim_token_emb(args) + elif args.dim_token is not None: + dim_token_emb = args.dim_token + else: + dim_token_emb = args.dim_local_decoder + return dim_token_emb + + +def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]: + if ngram_to_size_str is None: + return None + ngram_to_size = {} + for entry in ngram_to_size_str.split(","): + ngram, size = entry.split(":") + ngram = int(ngram) + size = int(size) + ngram_to_size[ngram] = size + return ngram_to_size + + +def fill_tokens(tokens, patch_size, fill_id): + batch_size, seq_len = tokens.shape + if seq_len % patch_size == 0: + return tokens + else: + remaining = patch_size - seq_len % patch_size + final_padding = tokens.new(batch_size, remaining).fill_(fill_id) + return torch.cat((tokens, final_padding), dim=1) + + +def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len): + first_patch_length = patch_lengths[0, 0] + assert torch.all( + first_patch_length == patch_lengths[:, 0] + ), "first patch should always be the same size (1 for dynamic, patch_size for static)." 
+ assert ( + first_patch_length - nb_boe == 1 + ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})" + # Remove first patch from patch_ids for local decoder inputs and shift the last patch. + # decoder_patch_lengths = patch_lengths[:, 1:].clone() + # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1) + decoder_patch_lengths = patch_lengths[:, 1:] + assert ( + decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0] + == patch_lengths.sum() + ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}" + assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}" + decoder_patch_ids = patch_ids_from_lengths( + patch_lengths=decoder_patch_lengths, seq_len=seq_len + ) + return decoder_patch_ids + + +primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, +] + + +def rolling_polynomial_hash(t, hash_func_nb: int = 0): + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) + prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) + return torch.sum(t * prime_powers, dim=-1) + + +def get_rolling_polynomial_hash_fn(hash_func_nb: int = 0, group_size: int = 2): + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64) + prime_powers = torch.stack([prime**i for i in range(group_size)]) + + def rolling_polynomial_hash_fn(t): + return torch.sum(t * prime_powers, dim=-1) + + return rolling_polynomial_hash_fn + + +def byte_group_hash_function( + x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 +): + """ + Returns a hash of the input x and maps it to a value in the range [0, max_hash]. + + expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. + returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. + + Note: max hash can make a big difference on the number of collisions. + """ + with torch.no_grad(): + bs, seq_len = x.shape + # x_numpy = x.numpy() + # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False) + # for i in range(bs): + # for j in range(seq_len): + # start = max(j, j-group_size+1) + # end = j+1 + # hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash) + + prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) + x = torch.cat([prefix, x], dim=1) + windows = x.unfold(1, group_size, 1) + # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values_range = hashes % max_hash + hash_values_range.requires_grad = False + return hash_values_range + + +def create_patch_mask_from_ids( + patch_ids, num_patches, window=None, patches_as_queries=False +): + """ + Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) + is True if the patch id at position (i, j) is less than or equal to k. + Args: + patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. + num_patches (int): Total number of patches. + window (int): If not None, only considers patches within a window of size window. + patches_as_queries (bool): If True, the patches are used as queries + Returns: + torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. 
+ """ + bs, seq_len = patch_ids.shape + if not patches_as_queries: + q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) + kv_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(bs, seq_len, num_patches) + ) + else: + kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) + q_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(bs, num_patches, seq_len) + ) + if window is None: + mask = q_ids == kv_ids + else: + mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) + return mask + + +def cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=1, + window=None, + block_mask=True, +): + bs = patch_ids.shape[0] + with torch.no_grad(): + # Create the patch mask + cross_mask = create_patch_mask_from_ids( + patch_ids, + patch_lengths.shape[1], + window=window, + patches_as_queries=patches_as_queries, + ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) + q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N + kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k + assert cross_mask.shape == ( + bs, + q_len, + kv_len, + ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + if block_mask: + + def patch_mask(b, h, q_idx, kv_idx): + return cross_mask[b, q_idx, kv_idx] + + block_mask = create_block_mask( + patch_mask, + B=bs, + H=None, + Q_LEN=q_len, + KV_LEN=kv_len, + _compile=True, + ) + return block_mask + else: + return torch.where( + cross_mask, torch.tensor(0.0), torch.tensor(float("-inf")) + ).unsqueeze( + 1 + ) # [bs, 1, q_len, kv_len] + + +def get_blt_input( + tokens: torch.Tensor, + enforce_patch_size_multiple: bool, + nb_boe: torch.Tensor, + patch_size: int, + boe_id: int, +): + """ + This function returns X_et, X_gt and X_dt, the encoder, global, and decoder + tokens respectively. + + Consider the input and target sequences: + X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13] + Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14] + with patch_size=4 + + Note 1: that there will be no special tokens introduced at the patch level. 
+ Note 2: X_e needs to be trimmed to be passed to Global + + Current without boe: + X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] + X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + --> lag fix: + X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]] + X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]] + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + Dynamic (current): + X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + entropy patching: + input: 7, bos, 9, 10 + pred (high entropy): eos, 8, 10, eos + + X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos] + X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]] + X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + --> lag fix no boe (force single byte first patch): + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + input: 4, 7, bos, 9, 10 + pred (high entropy): 5, eos, 8, 10, eos + + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + Handle the last byte properly. + patch_lengths = [1, 1, 3, 2, 2 1 2 2 1] + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]] + + + bpe delim + X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12] + X_g = [[3], [4,5,6,7,], [eos,bos,], .. + X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], .. + Y = [4,5,6,7,, eos,bos, 8,9,, .. + + + Note 1: that there will be no special tokens introduced at the patch level. + Note 2: X_e needs to be trimmed to be passed to Global + """ + batch_size, seq_len = tokens.shape + local_encoder_tokens = tokens + local_decoder_tokens = tokens + + if nb_boe > 0: + padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id) + local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1) + # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id) + + # create global tokens, contains boe tokens and eos + # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size) + # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:] + # global_tokens += global_tokens.eq(0).int() * boe_id + # TODO: fix this when we want to use block causal in the global. 
+ + if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0: + local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + + return local_encoder_tokens, None, local_decoder_tokens + + +def patch_ids_from_lengths(patch_lengths, seq_len): + bs, num_patches = patch_lengths.shape + # Create a tensor of cumulative sums of the patch lengths + cum_d = torch.cat( + [ + torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1), + ], + dim=-1, + ) + patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum( + dim=-2 + ) - 1 + assert not ( + torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0 + ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0" + return patch_ids + + +class ByteLatentTransformerArgs(BaseTransformerArgs): + model_config = ConfigDict(extra="forbid") + # Basic model configuration + seed: int = 42 + vocab_size: int = -1 + dim: int = 512 + n_layers: int = 8 + n_heads: int = 8 + # TODO: What is the purpose of this parameter? + weight_tying: bool = False + sliding_window: Optional[int] = None + + # Architecture and dimensions + dim_token: int = 256 + dim_global: int = 512 + dim_local_decoder: int = 512 + dim_local_encoder: int = 512 + n_layers_global: int = 8 + n_layers_local_decoder: int = 8 + n_layers_local_encoder: int = 8 + + # Tokenization and patching + tokenization_mode: str = "bpe" + patch_size: float | None = None + patching_mode: str | None = None + patching_threshold: float | None = None + patching_threshold_add: float | None = None + monotonicity: bool = False + patching_batch_size: int = 1 + patching_device: str = "cuda" + data_loader_patching: bool = False + max_patch_length: int | None = None + + # Encoder/Decoder configuration + tie_local_encoder_decoder_logits: bool = False + use_local_encoder_transformer: bool = False + encoder_lm_loss: bool = False + max_encoder_seq_length: int | None = None + pad_to_max_length: bool = False + encoder_enable_byte_ngrams: bool = False + encoder_enable_byte_group_hash: bool = False + ngram_vocab_sizes: int | None = None + + # Cross attention configurations + cross_attn_encoder: bool = False + cross_attn_decoder: bool = False + cross_attn_window_encoder: int | None = None + cross_attn_window_decoder: int | None = None + cross_attn_k: int | None = None + cross_attn_nheads: int | None = None + cross_attn_all_layers_decoder: bool = False + cross_attn_all_layers_encoder: bool = False + cross_attn_use_flex_attention: bool = True + cross_attn_init_by_pooling: bool = False + + # Encoder hash configurations + encoder_hash_byte_group_size: Any | None = None + encoder_hash_byte_group_vocab: int = 30000 + encoder_hash_byte_group_nb_functions: int = 3 + + # Model behavior and optimization + log_patch_lengths: bool = False + non_linearity: str = "swiglu" + use_rope: bool = True + recompute_fc1_out: bool = False + recompute_fc3_out: bool = False + recompute_attn: bool = True + custom_bwd: bool = False + layer_ckpt: str = "all" + efficient_attn: str | None = None + + # Architecture options + patch_only_encoder: bool = False + patch_only_decoder: bool = False + + # Initialization and attention + init_use_gaussian: bool = True + init_use_depth: str = "current" + attn_bias_type: str = "causal" + alpha_depth: str = "disabled" + max_length: int = 2048 + + # Norm configuration + norm_eps: float = 1e-5 + norm_affine: bool = True + pre_norm: bool = True + norm_type: str = "rmsnorm" + + 
# Additional configurations + multiple_of: int = 256 + ffn_dim_multiplier: float = 1.0 + dropout: float = 0 + output_size: int = -1 + + # Additional parameters from ModelArgs + architecture: str = "vanilla" + share_encoder_decoder_emb: bool = True + global_local_decoder_residual_layer: str | None = None + + tokenize_with_bpe_delimiter: bool = False + patching_thresholds_str: str | None = None + tie_local_encoder_decoder: bool = False + encoder_preds_low_entropy_toks: float | None = None + encoder_preds_random_toks: float | None = None + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + encoder_ngram_table_dir: str | None = None + encoder_ngram_to_size_str: str | None = None + + # Model architecture params + entropy_model_checkpoint_dir: str | None = None + entropy_model_is_ngram_model: bool = False + downsampling_by_pooling: str | None = None + n_heads_global: int = 8 + n_heads_local_decoder: int = 8 + n_heads_local_encoder: int = 8 + n_kv_heads: int | None = None + n_kv_heads_global: int | None = None + conv_kernel_size: int | None = None + local_attention_window_len: int | None = None + + # Performance optimization + sequence_parallel: bool = False + loss_parallel: bool = False + fuse_sequence_parallel: bool = False + use_fsdp: bool = True + attn_to_keep: str = "all" + + # RoPE parameters + rope_theta: float = 10000.0 + rope_use_fp32_in_outer_product: bool = False + + # Parameter mixing + pm_size: int = 0 + + # Logging + full_logging_n_layers: int = 4 + + # Special token config + eos_id: int | None = None + + @model_validator(mode="after") + def check_hash_byte_sizes(self) -> Self: + if ( + self.encoder_hash_byte_group_size is not None + and type(self.encoder_hash_byte_group_size) == str + ): + self.encoder_hash_byte_group_size = [ + int(x) + for x in self.encoder_hash_byte_group_size.split(",") + if len(x) > 0 + ] + return self + + +class LocalEncoderArgs(ByteLatentTransformerArgs): + # Local encoder specific dimensions + n_heads_local_encoder: int = 8 + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with local encoder specific values + self.dim = self.dim_local_encoder + self.n_layers = self.n_layers_local_encoder + self.n_heads = self.n_heads_local_encoder + self.cross_attn_decoder = False + self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None + self.attn_bias_type = "local_block_causal" + + +class GlobalTransformerArgs(ByteLatentTransformerArgs): + # Global encoder specific dimensions + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with global encoder specific values + self.dim = self.dim_global + self.n_layers = self.n_layers_global + self.n_heads = self.n_heads_global + self.n_kv_heads = self.n_kv_heads_global + self.local_attention_window_len = None + self.cross_attn_encoder = False + self.cross_attn_decoder = False + + +class LocalDecoderArgs(ByteLatentTransformerArgs): + # Local decoder specific dimensions + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with local decoder specific values + self.dim = self.dim_local_decoder + self.n_layers = self.n_layers_local_decoder + self.n_heads = self.n_heads_local_decoder + self.cross_attn_encoder = False + self.cross_attn_init_by_pooling = False + self.attn_bias_type = "local_block_causal" + + +def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer: 
+ global_args = args.model_copy( + deep=True, + update=dict( + dim=args.dim_global, + n_layers=args.n_layers_global, + n_heads=args.n_heads_global, + n_kv_heads=args.n_kv_heads_global, + local_attention_window_len=None, + dim_token_emb=get_global_dim_patch_emb(args), + dim_patch_emb=None, + cross_attn_encoder=False, + cross_attn_decoder=False, + ), + ) + + return GlobalTransformer(global_args) + + +def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: + # First deep copy the original args + # Replace with local encoder specific values + local_encoder_args = args.model_copy( + deep=True, + update=dict( + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + attn_bias_type="local_block_causal", + ), + ) + + return LocalEncoder(local_encoder_args) + + +def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: + # First deep copy the original args + local_decoder_args = args.model_copy( + deep=True, + update=dict( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + cross_attn_encoder=False, + cross_attn_init_by_pooling=False, # states are already defined + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + ), + ) + + return LocalDecoder(local_decoder_args) + + +class EmbeddingType(Enum): + HASH_TOK = auto() + NGRAM = auto() + + +def init_embeddings( + args, + embedding_type: EmbeddingType, + local_encoder_dim: int, + encoder_hash_byte_group_size: list = None, +): + if ( + embedding_type == EmbeddingType.HASH_TOK + and args.encoder_hash_byte_group_size is None + ): + return None + if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None: + return None + + embeddings = [] + + if embedding_type == EmbeddingType.HASH_TOK: + emb_dim = local_encoder_dim + encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab + for _ in range(args.encoder_hash_byte_group_nb_functions): + for _ in encoder_hash_byte_group_size: + embeddings.append( + nn.Embedding( + encoder_hash_byte_group_vocab, + emb_dim, + ) + ) + + elif embedding_type == EmbeddingType.NGRAM: + encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str) + emb_dim = local_encoder_dim + OFFSET = 4 # This should be passed as parameter if it's variable + for ngram_vocab_size in encoder_ngram_to_size.values(): + embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim)) + + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """ + Compute embeddings using hash token embeddings. 
+ + Args: + local_encoder_tokens: Input tokens tensor + local_encoder: Encoder object with tok_embeddings method + encoder_hash_tok_embedding: ModuleList of hash token embeddings + encoder_hash_byte_group_nb_functions: Number of hash functions + encoder_hash_byte_group_size: List of byte group sizes + encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings + + Returns: + torch.Tensor: Combined embeddings + """ + if encoder_hash_tok_embedding is None: + return None + + local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) + + i = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for byte_group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, + byte_group_size, + hash_func_nb=func_nb, + max_hash=encoder_hash_byte_group_vocab, + ) + hash_tok_embedding = encoder_hash_tok_embedding[i] + local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) + i += 1 + + assert i == len(encoder_hash_tok_embedding) + return local_encoder_embeds + + +class ByteLatentTransformer(nn.Module): + """ + The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences + by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, + and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for + improved performance and inference efficiency. + """ + + def __init__(self, args: ByteLatentTransformerArgs): + super().__init__() + + # General configuration + self.weight_tying = args.weight_tying + self.sliding_window = args.sliding_window + self.patch_size = args.patch_size + self.patching_mode = args.patching_mode + self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( + BOE_ID, + BOS_ID, + PAD_ID, + EOS_ID, + ) + self.downsampling_by_pooling = args.downsampling_by_pooling + self.patching_threshold = args.patching_threshold + self.dim = args.dim + self.init_base_std = args.init_base_std + self.init_std_factor = InitStdFactor(args.init_std_factor) + self.max_seqlen = args.max_seqlen + + # Cross attention configuration + self.cross_attn_encoder = args.cross_attn_encoder + self.cross_attn_decoder = args.cross_attn_decoder + self.cross_attn_k = args.cross_attn_k + self.cross_attn_window_encoder = args.cross_attn_window_encoder + self.cross_attn_window_decoder = args.cross_attn_window_decoder + self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention + + # Encoder hash configuration + self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size + self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab + self.encoder_hash_byte_group_nb_functions = ( + args.encoder_hash_byte_group_nb_functions + ) + + # ByteLatent modules + self.local_encoder = create_local_encoder(args) + self.global_transformer = create_global_transformer(args) + self.local_decoder = create_local_decoder(args) + self.encoder_hash_tok_embedding = init_embeddings( + args, + EmbeddingType.HASH_TOK, + local_encoder_dim=self.local_encoder.dim, + encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, + ) + self.encoder_ngram_embedding = init_embeddings( + args, + EmbeddingType.NGRAM, + local_encoder_dim=self.local_encoder.dim, + encoder_hash_byte_group_size=None, + ) + self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) + + # Transformer layers + self.layers = nn.ModuleList( + [TransformerBlock(args) for _ in range(args.n_layers)] + ) + + # Encoder ngram 
embedding tables + self.encoder_ngram_embedding = None + if args.encoder_enable_byte_ngrams: + self.encoder_ngram_embedding = nn.ModuleList() + assert args.ngram_vocab_sizes is not None + self.encoder_ngram_to_size = parse_ngram_to_size( + args.encoder_ngram_to_size_str + ) + ngram_emb_dim = self.local_encoder.dim + for ngram_vocab_size in self.encoder_ngram_to_size.values(): + self.encoder_ngram_embedding.append( + nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim) + ) + + # Output layer + assert args.vocab_size > 0, "vocab_size must be greater than 0" + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + if args.weight_tying: + self.output.weight = self.tok_embeddings.weight + + # Patcher module + if not args.data_loader_patching: + self.patcher = Patcher( + PatcherArgs( + patch_size=args.patch_size, + patching_mode=args.patching_mode, + patching_threshold=args.patching_threshold, + patching_threshold_add=args.patching_threshold_add, + monotonicity=args.monotonicity, + max_patch_length=args.max_patch_length, + ) + ) + + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + ngram_ids: Optional[torch.Tensor] = None, + ): + # Ensure ngram_ids is either a tensor or None + assert ( + isinstance(ngram_ids, torch.Tensor) or ngram_ids is None + ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}" + + bs, N = tokens.shape # Batch size and sequence length + + # Get megabyte inputs + nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1) + local_encoder_tokens, _, local_decoder_tokens = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=nb_boe, + patch_size=self.patch_size, + boe_id=self.boe_id, + ) + + # Patching + if patch_lengths is None: + assert ( + getattr(self, "patcher", None) is not None + ), "Patcher not defined and no patch_lengths passed." 
+ patch_lengths, tok_scores = self.patcher.patch( + local_encoder_tokens, + include_next_token=True, + threshold=self.patcher.threshold, + ) + else: + if nb_boe > 0: + patch_lengths[:, 0] += nb_boe + + assert torch.min(patch_lengths) >= 0 + + # Generate patch IDs from patch_lengths + patch_ids = patch_ids_from_lengths( + patch_lengths, local_encoder_tokens.shape[-1] + ) + assert torch.max(patch_ids) + 1 <= torch.max( + (patch_lengths != 0).sum(dim=-1) + ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" + + cross_attn_mask_enc = None + # Cross-attention encoder + if self.cross_attn_encoder: + cross_attn_mask_enc = cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=True, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_encoder, + block_mask=self.cross_attn_use_flex_attention, + ) + + # Hashing and embedding + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=self.local_encoder, + encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab, + ) + + # N-gram table embeddings + if self.encoder_ngram_embedding is not None: + assert ngram_ids is not None, "ngram_ids must be provided" + if local_encoder_embeds is None: + local_encoder_embeds = self.local_encoder.tok_embeddings( + local_encoder_tokens + ) + assert len(ngram_ids) == len( + self.encoder_ngram_embedding + ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}" + for i in range(ngram_ids.shape[0]): + ngram_embedding = self.encoder_ngram_embedding[i] + ngram_embeds = ngram_embedding(ngram_ids[i]) + assert ( + local_encoder_embeds.shape == ngram_embeds.shape + ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}" + local_encoder_embeds = local_encoder_embeds + ngram_embeds + + # Local encoder + h_cross = None + (h_encoder, h_cross), cache_encoder = self.local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=h_cross if self.cross_attn_encoder else None, + cross_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + # Downsampling + if not self.cross_attn_encoder: + assert ( + patch_ids.shape[1] == h_encoder.shape[1] + ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}" + h = downsample( + h_encoder, + patch_lengths.shape[1], + patch_lengths, + patch_ids, + downsampling_by_pooling=self.downsampling_by_pooling, + patch_size=self.patch_size, + ) + else: + # Reshape h_cross + h = h_cross.view(bs, patch_lengths.shape[1], -1) + + # Global transformer + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id) + rows, cols = torch.where(local_encoder_tokens == self.eos_id) + eos_patch_ids = patch_ids[rows, cols] + global_tokens[rows, eos_patch_ids] = self.eos_id + + h, _ = self.global_transformer( + embeds=h, + tokens=global_tokens, + ) + + # Unpatching + dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] + + # Generate decoder patch IDs + decoder_patch_ids = decoder_patch_ids_from_lengths( + patch_lengths, nb_boe, local_decoder_tokens.shape[-1] + ) + assert ( + torch.max(decoder_patch_ids) + 1 <= h.shape[1] + ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + 
assert ( + decoder_patch_ids.shape[1] == dec_embeds.shape[1] + ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + + # Cross-attention decoder + if not self.cross_attn_decoder: + h = torch.gather( + h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + ) + cross_attn_mask_dec = None + assert local_decoder_tokens.shape == h.shape[:-1] + else: + cross_attn_mask_dec = cross_attn_mask( + decoder_patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_decoder, + block_mask=self.cross_attn_use_flex_attention, + ) + + # Local decoder + output, _ = self.local_decoder( + embeds=dec_embeds, + patch_embeds=h, + tokens=local_decoder_tokens, + cross_mask=cross_attn_mask_dec, + ) + return output + + def reset_parameters(self, init_std=None): + # Either use fixed base std or sqrt model dim + init_std = init_std or (self.dim ** (-0.5)) + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + if not self.weight_tying: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + def init_weights(self): + self.reset_parameters() + self.init_base_std = self.init_base_std or (self.dim ** (-0.5)) + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(self.init_base_std, factor) + + self.local_decoder.init_weights(self.init_base_std) + self.global_transformer.init_weights(self.init_base_std) + self.local_encoder.init_weights(self.init_base_std) + + for emb in self.encoder_hash_tok_embedding: + nn.init.trunc_normal_( + emb.weight, + mean=0.0, + std=self.init_base_std, + a=-3 * self.init_base_std, + b=3 * self.init_base_std, + ) diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py new file mode 100644 index 0000000..8255504 --- /dev/null +++ b/bytelatent/model/local_models.py @@ -0,0 +1,356 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
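As a brief aside before the local encoder/decoder modules, a tiny worked example (toy values, not taken from the patch) of two patching helpers defined in bytelatent/model/blt.py above: `patch_ids_from_lengths` maps every byte position to the index of the patch containing it, and `byte_group_hash_function` produces the rolling-hash ids that index the hash embedding tables.

```python
import torch

from bytelatent.model.blt import byte_group_hash_function, patch_ids_from_lengths

tokens = torch.tensor([[3, 4, 5, 6, 7, 1, 0, 8]])   # (bs=1, seq_len=8), toy byte ids
patch_lengths = torch.tensor([[1, 4, 2, 1]])        # lengths sum to seq_len

# Every byte position gets the index of the patch that contains it.
patch_ids = patch_ids_from_lengths(patch_lengths, seq_len=tokens.shape[1])
assert patch_ids.tolist() == [[0, 1, 1, 1, 1, 2, 2, 3]]

# Rolling polynomial hash over 3-byte windows, bucketed into [0, max_hash).
hash_ids = byte_group_hash_function(tokens, group_size=3, hash_func_nb=0, max_hash=500)
assert hash_ids.shape == tokens.shape
assert int(hash_ids.max()) < 500
```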
+ +import logging +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn +import torch.nn as nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import BlockMask +from xformers.ops import AttentionBias + +from bytelatent.base_transformer import ( + InitStdFactor, + RMSNorm, + RotaryEmbedding, + TransformerBlock, +) +from bytelatent.model.transformer import CrossAttention +from bytelatent.model.utils import create_causal_mask, downsample +from bytelatent.tokenizers.blt_tokenizer import BOE_ID + +logger = logging.getLogger() + + +class LocalModelBase(nn.Module): + def __init__(self, args): + super().__init__() + + self.dim = args.dim + self.dropout = args.dropout + self.vocab_size = args.vocab_size + args.pm_size + self.patch_size = args.patch_size + + self.efficient_attn = args.efficient_attn + self.sliding_window = args.sliding_window + self.use_rope = args.use_rope + self.init_std_factor = args.init_std_factor + self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) + self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) + self.cross_attn_k = getattr(args, "cross_attn_k", None) + + self.boe_id = BOE_ID + + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.layers = nn.ModuleList( + [TransformerBlock(args) for _ in range(args.n_layers)] + ) + + self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim) + if not self.use_rope: + self.pos_embeddings = nn.Embedding(args.max_length, args.dim) + else: + self.rope = RotaryEmbedding( + theta=args.rope_theta, + head_dim=args.head_dim or args.dim // args.n_heads, + max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + ) + self.pos_embeddings = None + + self.token_embedding_projection = ( + nn.Linear(args.dim_token_emb, args.dim, bias=False) + if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(args) + + def _should_create_patch_projection(self, args): + dimension_mismatch = ( + getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim + ) + + # Check cross attention conditions + cross_attn_conditions = ( + hasattr(args, "cross_attn_encoder") + and args.cross_attn_encoder + and getattr(args, "cross_attn_init_by_pooling") + ) or ( + hasattr(args, "cross_attn_decoder") + and args.cross_attn_decoder + and getattr(args, "cross_attn_init_by_pooling") + ) + + return dimension_mismatch or cross_attn_conditions + + def _create_patch_projection(self, args): + if not self._should_create_patch_projection(args): + return None + + output_dim = args.dim_token_emb * (self.cross_attn_k or 1) + + return nn.Linear( + in_features=args.dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + return embeds + else: + return self.tok_embeddings(tokens) + + def init_weights(self, init_std=None): + self.rope.reset_parameters() + + init_std = init_std or (self.dim ** (-0.5)) + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + if self.pos_embeddings is not None: + nn.init.trunc_normal_( + self.pos_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + 
InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(init_std, factor) + + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.patch_embedding_projection is not None: + nn.init.trunc_normal_( + self.patch_embedding_projection.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if hasattr(self, "output"): + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.cross_attn_layers is not None: + for depth, layer in enumerate(self.cross_attn_layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(init_std, factor) + + +class LocalEncoder(LocalModelBase): + def __init__(self, args): + super().__init__(args) + self.output_proj = ( + args.patching_mode in ["entropy", "probmax"] + ) and args.entropy_model_checkpoint_dir is None + + self.apply_transformer = args.use_local_encoder_transformer + self.downsampling_by_pooling = args.downsampling_by_pooling + self.patch_only = args.patch_only_encoder + self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None + self.cross_attn_encoder = args.cross_attn_encoder + self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling + self.cross_attn_nheads = args.cross_attn_nheads + + if self.cross_attn_encoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + CrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=args.norm_eps, + ) + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + assert ( + self.expects_hash_embeddings + ), "Not expecting embeddings to be passed." 
+ return embeds + else: + return self.tok_embeddings(tokens) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + bs, seqlen = tokens.shape + if mask is None: + mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + + h = self.apply_embedding(tokens, embeds) + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + + for i, layer in enumerate(self.layers): + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + # check if cross attention should be applied to either all layer or only the last layer + if self.cross_attn_encoder and ( + i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder + ): + patch_embeds = self.apply_cross_attention( + h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask + ) + + h_residual = patch_embeds if self.cross_attn_encoder else None + return (h, h_residual), cache + + def apply_cross_attention( + self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask + ): + # apply pooling and project + if self.cross_attn_init_by_pooling and patch_embeds is None: + patch_embeds = downsample( + h, + num_patches, + patch_ids=patch_ids, + downsampling_by_pooling=self.downsampling_by_pooling, + patch_size=self.patch_size, + ) + if self.patch_embedding_projection is not None: + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0 + patch_embeds_cross = self.cross_attn_layers[layer_idx]( + x=patch_embeds, + kv=h, + mask=cross_mask, + ) + patch_embeds += patch_embeds_cross + return patch_embeds + + +class LocalDecoder(LocalModelBase): + def __init__(self, args): + super().__init__(args) + + # Model configuration flags + self.patch_only = args.patch_only_decoder + self.expects_embeddings = args.share_encoder_decoder_emb + self.cross_attn_decoder = args.cross_attn_decoder + self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder + self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling + self.cross_attn_nheads = args.cross_attn_nheads + + if self.cross_attn_decoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + CrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=args.norm_eps, + ) + ) + + self.output = nn.Linear( + self.dim, + args.vocab_size, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + assert embeds is not None, "Embeddings must be provided" + + if mask is None: + mask = 
create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + + h = embeds + + if self.patch_embedding_projection is not None: + assert patch_embeds is not None, "Patch embeddings must be passed." + patch_embeds = self.patch_embedding_projection(patch_embeds) + if self.cross_attn_k is not None: + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + if patch_embeds is not None and not self.cross_attn_decoder: + h = h + patch_embeds + + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if self.cross_attn_decoder and ( + i == 0 or self.cross_attn_all_layers_decoder + ): + # Use cross attention to extract info from patch_embeds into h + h_cross = self.cross_attn_layers[i]( + x=h, + kv=patch_embeds, + mask=cross_mask, + ) + h = h + h_cross + + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + + h_preds = self.norm(h) + h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) + h_preds = self.output(h_preds) + h_preds = h_preds.float() + return h_preds, cache diff --git a/bytelatent/model/transformer.py b/bytelatent/model/transformer.py new file mode 100644 index 0000000..24dc057 --- /dev/null +++ b/bytelatent/model/transformer.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn +import torch.nn as nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import BlockMask +from xformers.ops import AttentionBias + +from bytelatent.base_transformer import ( + BaseTransformer, + RMSNorm, + flex_attention_comp, + repeat_kv, +) +from bytelatent.model.utils import create_causal_mask + +logger = logging.getLogger() + + +class CrossAttention(nn.Module): + """ + CrossAttention block to attend to the encoder states from the decoder. + Rope is not supported. 
+ """ + + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + kv: torch.Tensor, + mask: Optional[Union[BlockMask, AttentionBias, str]] = None, + ) -> torch.Tensor: + # B S D + bsz, seq_len, _ = x.shape + _, slen_kv, _ = kv.shape + x = self.cross_attn_norm_q(x) + kv = self.cross_attn_norm_kv(kv) + + xq = self.wq(x) + xk = self.wk(kv) + xv = self.wv(kv) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + output = self.wo(output.reshape(output_shape)) + + return x + output + + def init_weights(self, base_std: float, factor: float = 1.0): + std = base_std * factor + + nn.init.trunc_normal_( + self.wq.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wk.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wv.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + output_std = std / (2**0.5) + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=output_std, + a=-3 * output_std, + b=3 * output_std, + ) + self.cross_attn_norm_q.reset_parameters() + self.cross_attn_norm_kv.reset_parameters() + + +class GlobalTransformer(BaseTransformer): + def __init__(self, args): + super().__init__(args) + self.dropout = args.dropout + self.sliding_window = args.sliding_window + self.efficient_attn = args.efficient_attn + + self.token_embedding_projection = None + if args.dim_token_emb is not None and args.dim_token_emb != self.dim: + self.token_embedding_projection = nn.Linear( + args.dim_token_emb, + args.dim, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + embeds: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ + Similar to BaseTransformer.forward, but with an additional embeds argument + and projection to the token space. 
+ """ + bs, seqlen = tokens.shape + attn_impl = self.efficient_attn + + h = embeds + + mask = ( + mask + if mask is not None + else create_causal_mask(seqlen, attn_impl, self.sliding_window) + ) + + if self.token_embedding_projection is not None and h.shape[-1] != self.dim: + h = self.token_embedding_projection(h) + + h = F.dropout(h, p=self.dropout, training=self.training) + + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + return h, cache + + def init_weights(self, init_base_std: float): + super().init_weights() + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=init_base_std, + a=-3 * init_base_std, + b=3 * init_base_std, + ) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py new file mode 100644 index 0000000..ce52a30 --- /dev/null +++ b/bytelatent/model/utils.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +from torch.nn.attention.flex_attention import create_block_mask +from xformers.ops import fmha + + +def patch_reduce(h, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + bs, seq_len, emb_dim = h.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + + reduced_embs = torch.zeros( + (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device + ) + reduced_embs = reduced_embs.scatter_reduce( + src=h, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embs = reduced_embs[:, :max_num_patches, :] + + return reduced_embs + + +def concat_downsample(h, patch_lengths, patch_size): + # The assumption in this function is that seq_len = patch_size * num_patches. + bs, seq_len, emb_dim = h.shape + patch_end_ids = torch.cumsum(patch_lengths, dim=1) + patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to( + patch_end_ids.device + ) + # Is clamp ok here? + patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1]) + patch_ids = patch_ids.view(bs, -1, emb_dim) + # after gather h.shape = [batch_size, seq_len, dim] + h = torch.gather(h, 1, patch_ids) + h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1)) + return h + + +def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids): + cat = [] + if "avg" in pooling_mode or "mean" in pooling_mode: + cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids)) + if "min" in pooling_mode: + cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids)) + if "max" in pooling_mode: + cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids)) + assert len(cat) > 0 + h = torch.cat(cat, dim=-1) + return h + + +def downsample( + h, + num_patches, + patch_lengths=None, + patch_ids=None, + downsampling_by_pooling=None, + patch_size=4, +): + """ + Downsampling: + a. concatenating embeddings in the patch + Note: with dynamic patching, patch the last patch_size tokens. + b. 
pooling embeddings in the patch + """ + # input: h.shape = [batch_size, seq_len, dim] + # input: pool h.shape = [batch_size, seq_len / patch_size, dim] + # if we don't use the cros_attn, we pool so that we convert bytes rep to patch rep + if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0: + # By pooling + max_num_patches = num_patches + assert patch_ids is not None + h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids) + else: + # TODO: remove this condition + # By concatenating (fixed lengths patching) + assert patch_lengths is not None + h = concat_downsample(h, patch_lengths, patch_size) + return h + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def create_causal_mask(seqlen, attn_impl, sliding_window): + if sliding_window is not None and attn_impl == "xformers": + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) + elif attn_impl == "xformers": + return fmha.attn_bias.LowerTriangularMask() + elif attn_impl == "sdpa": + return "causal" + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + elif attn_impl == "fmha": + return None + else: + raise NotImplementedError( + f"Attention {attn_impl} with {sliding_window} sliding window not implemented" + ) diff --git a/bytelatent/optim.py b/bytelatent/optim.py new file mode 100644 index 0000000..c6e1829 --- /dev/null +++ b/bytelatent/optim.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import logging +import math +from functools import partial + +from pydantic import BaseModel, ConfigDict +from torch import nn +from torch.optim import AdamW, lr_scheduler + +logger = logging.getLogger() + + +class OptimArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + lr: float = 3e-4 + weight_decay: float = 0.1 + epsilon: float = 1e-8 + beta1: float = 0.9 + beta2: float = 0.95 + clip: float = 1.0 + + scheduler: str = "cosine" + warmup: int = 2000 + lr_min_ratio: float = 0.1 + cycle_length: float = 1.0 + cosine_theta: float = 1.0 + annealing_step: int = 1000 + decay_fraction: float = 0.1 + + exp_factor: float = 0.5 + + +def lr_linear(step: int, warmup: int, n_steps: int, min_ratio: float) -> float: + if step < warmup: + lr = float(step) / warmup + elif step <= n_steps: + s = float(step - warmup) / (n_steps - warmup) + lr = s * min_ratio + (1 - s) + else: + lr = min_ratio + return lr + + +def lr_inv_sqrt(step: int, warmup: int, exp_factor: float, min_ratio: float) -> float: + if step < warmup: + lr = float(step) / warmup + else: + lr = max((warmup**exp_factor) / (step**exp_factor), min_ratio) + return lr + + +def lr_cosine( + step: int, + warmup: int, + n_steps: int, + cycle_length: float, + theta: float, + min_ratio: float, +) -> float: + sign = ((step // (n_steps * cycle_length)) % 2) * -2 + 1 + if step < warmup: + lr = float(step) / warmup + elif step <= n_steps: + s = float(step - warmup) / (n_steps - warmup) + lr = min_ratio + 0.5 * (1 - min_ratio) * ( + sign * math.cos(math.pi * s**theta / cycle_length) + 1 + ) + else: + lr = min_ratio + return lr + + +def lr_wsd( + step: int, + warmup: int, + n_steps: int, + decay_fraction: float, + cycle_length: float, + min_ratio: float, +) -> float: + """ + UNDERSTANDING WARMUP-STABLE-DECAY LEARNING RATES: A RIVER VALLEY LOSS LANDSCAPE PERSPECTIVE + https://arxiv.org/pdf/2410.05192 + """ + cycle_num = step // int(n_steps * cycle_length) + 1 + curr_n_steps = int(n_steps * cycle_length) * 
cycle_num + decay_length = int(curr_n_steps * decay_fraction) + + if step < warmup: + lr = float(step) / warmup + elif step <= curr_n_steps - decay_length: + lr = 1.0 + elif step > curr_n_steps - decay_length and step <= curr_n_steps: + # Linear interpolation gives similar results + # slope = -(1.0 - min_ratio) / decay_length + # intercept = min_ratio + ((1.0 - min_ratio) * curr_n_steps) / decay_length + # lr = slope * step + intercept + + step = step - (curr_n_steps - decay_length) + lr = 1 / ((step / curr_n_steps) * (1 / min_ratio) + (1 - step / curr_n_steps)) + else: + lr = min_ratio + + return lr + + +def build_lr_fn(args: OptimArgs, n_steps: int): + if args.scheduler == "constant": + lr_fn = lambda x: 1.0 + elif args.scheduler == "linear": + lr_fn = partial( + lr_linear, warmup=args.warmup, n_steps=n_steps, min_ratio=args.lr_min_ratio + ) + elif args.scheduler == "inv_sqrt": + lr_fn = partial( + lr_inv_sqrt, + warmup=args.warmup, + exp_factor=args.exp_factor, + min_ratio=args.lr_min_ratio, + ) + elif args.scheduler == "cosine": + lr_fn = partial( + lr_cosine, + warmup=args.warmup, + n_steps=n_steps, + cycle_length=args.cycle_length, + theta=args.cosine_theta, + min_ratio=args.lr_min_ratio, + ) + elif args.scheduler == "wsd": + assert args.decay_fraction < args.cycle_length + lr_fn = partial( + lr_wsd, + warmup=args.warmup, + n_steps=n_steps, + decay_fraction=args.decay_fraction, + cycle_length=args.cycle_length, + min_ratio=args.lr_min_ratio, + ) + else: + raise NotImplementedError(f"Unknown scheduler: {args.scheduler}") + return lr_fn + + +def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int): + logger.info("Starting build of optimizer...") + optimizer = AdamW( + model.parameters(), + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + eps=args.epsilon, + fused=True, # Faster optim.step but can throw errors + ) + + # scheduler + lr_fn = build_lr_fn(args, n_steps) + scheduler = lr_scheduler.LambdaLR(optimizer, lr_fn) + + logger.info("Done with build of optimizer.") + return optimizer, scheduler diff --git a/bytelatent/plotting/__init__.py b/bytelatent/plotting/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/plotting/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/plotting/config_entropy_figure.yaml b/bytelatent/plotting/config_entropy_figure.yaml new file mode 100644 index 0000000..4d7bfd7 --- /dev/null +++ b/bytelatent/plotting/config_entropy_figure.yaml @@ -0,0 +1,3 @@ +data_path: plot_data/entropy_figure.json +chart_path: figures/entropy_figure.pdf +# chart_path: figures/entropy_figure.pdf diff --git a/bytelatent/plotting/config_scaling_figures.yaml b/bytelatent/plotting/config_scaling_figures.yaml new file mode 100644 index 0000000..cda85c2 --- /dev/null +++ b/bytelatent/plotting/config_scaling_figures.yaml @@ -0,0 +1,4 @@ +df_dir: /home/par/blt_df +output_chart_dir: figures/ +frame_files: + ["4b_df.json", "500m_df.json", "scaling_arch_df.json", "scaling_df.json"] diff --git a/bytelatent/plotting/entropy_figure.py b/bytelatent/plotting/entropy_figure.py new file mode 100644 index 0000000..c1966a1 --- /dev/null +++ b/bytelatent/plotting/entropy_figure.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
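A minimal sketch (not part of the patch) of how the pieces in `bytelatent/optim.py` above fit together: `build_lr_fn` returns a per-step multiplier in `[lr_min_ratio, 1.0]` that `LambdaLR` applies on top of `OptimArgs.lr`, and `build_optimizer` does the same wiring with the fused AdamW shown above. The step values below are arbitrary.

```python
from torch import nn
from torch.optim import AdamW, lr_scheduler

from bytelatent.optim import OptimArgs, build_lr_fn

args = OptimArgs(lr=3e-4, scheduler="cosine", warmup=2000, lr_min_ratio=0.1)
n_steps = 10_000

lr_fn = build_lr_fn(args, n_steps)                    # step -> multiplier
optimizer = AdamW(nn.Linear(16, 16).parameters(), lr=args.lr)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_fn)   # same wiring as build_optimizer

for step in (0, 1000, 2000, 6000, 9999):
    # linear warmup to 3e-4 by step 2000, then cosine decay toward lr * lr_min_ratio
    print(step, args.lr * lr_fn(step))
```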
+import os +import sys +from pathlib import Path + +import altair as alt +import pandas as pd +from omegaconf import OmegaConf +from pydantic import BaseModel + + +class PlotEntropiesConfig(BaseModel): + data_path: str | None + chart_path: str + + class Config: + extra = "forbid" + + +class PlotEntropiesData(BaseModel): + text: str + threshold: float = 1.335442066192627 + dataframe_json: str | None + + class Config: + extra = "forbid" + + +def main(): + config_path = sys.argv[1] + file_config = OmegaConf.load(config_path) + # Omit program name and config file name + cli_conf = OmegaConf.from_cli(sys.argv[2:]) + conf_dict = OmegaConf.to_container( + OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True + ) + plot_config = PlotEntropiesConfig(**conf_dict) + with open(plot_config.data_path) as f: + json_data = f.read() + plot_data = PlotEntropiesData.model_validate_json(json_data) + df = pd.read_json(plot_data.dataframe_json) + + x_ticks = [] + for row in df.itertuples(): + position = row.position + token = row.tokens + x_ticks.append(f"{str(position).zfill(3)}|{token}") + df["position_with_token"] = x_ticks + print(df) + + x_axis = alt.Axis( + labelExpr="split(datum.label, '|')[1]", + grid=False, + labelOverlap=False, + labelAngle=0, + ) + width = 1200 + height = 150 + base = alt.Chart(df).properties(width=width, height=height) + points = base.mark_line(point=True).encode( + x=alt.X("position_with_token:O", title=None, axis=x_axis), + y=alt.Y( + "entropies", + title="Entropy of Next Byte", + ), + ) + rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode( + y=alt.datum(plot_data.threshold), + ) + patch_rules = ( + alt.Chart(df[df["start"] > 0]) + .properties(width=width, height=height) + .mark_rule(color="#474747", strokeDash=[4, 2]) + .encode(x=alt.X("position_with_token:O", axis=x_axis)) + ) + + chart = patch_rules + rule + points + chart = chart.configure_axis(labelFontSize=15, titleFontSize=15) + path = Path(plot_config.chart_path) + path.parent.mkdir(exist_ok=True) + chart.save(path) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/plotting/scaling_figures.py b/bytelatent/plotting/scaling_figures.py new file mode 100644 index 0000000..a49624a --- /dev/null +++ b/bytelatent/plotting/scaling_figures.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
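For reference, a hedged sketch of the input that `entropy_figure.py` above consumes: a JSON file with `text`, `threshold`, and a `dataframe_json` whose frame carries the `position`, `tokens`, `entropies`, and `start` columns the script reads. The demo path and numbers are placeholders, not the shipped `plot_data/entropy_figure.json`.

```python
import json

import pandas as pd

df = pd.DataFrame(
    {
        "position": [0, 1, 2, 3],
        "tokens": ["D", "a", "t", "a"],
        "entropies": [2.1, 0.9, 0.4, 0.3],  # placeholder values
        "start": [1, 0, 0, 0],              # non-zero rows get a patch-boundary rule
    }
)
payload = {
    "text": "Data",
    "threshold": 1.335442066192627,
    "dataframe_json": df.to_json(),
}
with open("plot_data/entropy_figure_demo.json", "w") as f:
    json.dump(payload, f)

# python -m bytelatent.plotting.entropy_figure \
#     bytelatent/plotting/config_entropy_figure.yaml \
#     data_path=plot_data/entropy_figure_demo.json chart_path=figures/demo.pdf
```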
+import sys +from pathlib import Path + +import altair as alt +import pandas as pd +import pydantic +from omegaconf import OmegaConf + + +class ScalingPlotsConfig(pydantic.BaseModel): + df_dir: str + output_chart_dir: str + frame_files: list[str] + + class Config: + extra = "forbid" + + +def determine_family(key: str): + if key.startswith("Megabyte++"): + return "Megabyte++" + elif key.startswith("BLT"): + return "BLT" + elif key.startswith("LLaMA"): + return "LLaMA" + elif key.startswith("Space"): + return "Space" + + +file_to_vars = {} + + +def create_chart(df: pd.DataFrame, output_file: str): + df["metric"] = df["bpb/not_heldout.jsonl"] + df["family"] = df["key"].map(determine_family) + model_domain = [ + "BLT Space ps=6", + "BLT Space w/o cross-attn", + "SpaceByte", + "LLaMA 3 BPE", + "Megabyte++ ps=4", + "Megabyte++ ps=6", + ] + color_range = ["#1f77b4", "#1f77b4", "#1f77b4", "#ff7f0e", "#2ca02c", "#2ca02c"] + shape_range = [ + "circle", + "square", + "cross", + "diamond", + "triangle-up", + "triangle-down", + ] + color_scale = alt.Scale(domain=model_domain, range=color_range) + shape_scale = alt.Scale( + domain=model_domain, + range=shape_range, + ) + base_chart = alt.Chart(df).encode( + x=alt.X("flops", title="Training FLOPS") + .scale(type="log", domain=[2e20, 1.25e22]) + .axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]), + y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False), + ) + lines = base_chart.encode( + color=alt.Color("key", title="Model Color", scale=color_scale, legend=None), + strokeDash=alt.StrokeDash("family", title="Model Family", legend=None), + ).mark_line() + points = base_chart.encode( + color=alt.Color("key", title="Model", scale=color_scale), + shape=alt.Shape("key", title="", scale=shape_scale), + ).mark_point(size=70) + chart = ( + (lines + points) + .resolve_scale( + color="independent", + shape="independent", + # strokeDash="independent", + ) + .configure_legend(orient="right") + .properties(height=300, width=400) + ) + print("Saving", output_file) + chart.save(output_file) + + +def main(): + config_path = sys.argv[1] + file_config = OmegaConf.load(config_path) + # Omit program name and config file name + cli_conf = OmegaConf.from_cli(sys.argv[2:]) + conf_dict = OmegaConf.to_container( + OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True + ) + plot_config = ScalingPlotsConfig(**conf_dict) + df_dir = Path(plot_config.df_dir) + chart_dir = Path(plot_config.output_chart_dir) + chart_dir.mkdir(exist_ok=True, parents=True) + for ff in plot_config.frame_files: + path = df_dir / ff + df = pd.read_json(path) + print(df) + print(df.columns) + create_chart(df, chart_dir / f"{path.name}.pdf") + + +if __name__ == "__main__": + main() diff --git a/bytelatent/preprocess/__init__.py b/bytelatent/preprocess/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/preprocess/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/preprocess/data_pipeline.py b/bytelatent/preprocess/data_pipeline.py new file mode 100644 index 0000000..7dd5283 --- /dev/null +++ b/bytelatent/preprocess/data_pipeline.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
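A sketch of the frame format `create_chart` in `scaling_figures.py` above expects: one row per model/FLOPS point, with a `key` matching one of the `model_domain` labels, a `flops` column, and the `bpb/not_heldout.jsonl` metric it copies into `metric`. The directory comes from the sample YAML config; the numbers are placeholders, not measured results.

```python
import pandas as pd

rows = [
    {"key": "BLT Space ps=6", "flops": 4e20, "bpb/not_heldout.jsonl": 0.84},
    {"key": "BLT Space ps=6", "flops": 8e20, "bpb/not_heldout.jsonl": 0.80},
    {"key": "LLaMA 3 BPE", "flops": 4e20, "bpb/not_heldout.jsonl": 0.88},
    {"key": "LLaMA 3 BPE", "flops": 8e20, "bpb/not_heldout.jsonl": 0.85},
]
pd.DataFrame(rows).to_json("/home/par/blt_df/500m_df.json")

# python -m bytelatent.plotting.scaling_figures bytelatent/plotting/config_scaling_figures.yaml
```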
+import subprocess +from pathlib import Path + +import luigi + +# CHANGEME: Change this to point to your data +BASE_DIR = Path("datasets") +DATASETS = ["dclm"] +TARGET_DIR = Path("entropy_preprocess") + +SHARD_SCRIPT = """split -C 2500m -d {source} {destination}.shard_""" + + +def list_dataset_shards(dataset: str): + dataset_dir = BASE_DIR / dataset + return list(dataset_dir.glob("*.chunk.*.jsonl")) + + +class ChunkFile(luigi.ExternalTask): + file = luigi.Parameter() + + def output(self): + return luigi.LocalTarget(self.file) + + +class ShardDatasetChunk(luigi.Task): + dataset_name = luigi.Parameter() + chunk_file = luigi.Parameter() + + def _chunk_filename(self): + return Path(self.chunk_file).name + + def requires(self): + return ChunkFile(self.chunk_file) + + def run(self): + destination_dir = TARGET_DIR / str(self.dataset_name) + destination_dir.mkdir(parents=True, exist_ok=True) + destination = destination_dir / self._chunk_filename() + subprocess.check_output( + SHARD_SCRIPT.format(source=str(self.chunk_file), destination=destination), + shell=True, + ) + ( + Path(TARGET_DIR) + / str(self.dataset_name) + / f"{self._chunk_filename()}.shard.COMPLETE" + ).touch() + + def output(self): + return luigi.LocalTarget( + TARGET_DIR + / str(self.dataset_name) + / f"{self._chunk_filename()}.shard.COMPLETE" + ) + + +class ShardDataset(luigi.WrapperTask): + dataset_name = luigi.Parameter() + + def requires(self): + for f in list_dataset_shards(self.dataset_name): + yield ShardDatasetChunk(dataset_name=self.dataset_name, chunk_file=str(f)) + + +class ShardAllDatasets(luigi.WrapperTask): + def requires(self): + for d in DATASETS: + yield ShardDataset(dataset_name=d) + + +if __name__ == "__main__": + luigi.build([ShardAllDatasets()], local_scheduler=True, workers=128) diff --git a/bytelatent/preprocess/parallel_entropies.py b/bytelatent/preprocess/parallel_entropies.py new file mode 100644 index 0000000..ac2dbd4 --- /dev/null +++ b/bytelatent/preprocess/parallel_entropies.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
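A rough sketch of what the Luigi pipeline in `data_pipeline.py` above produces, assuming the `BASE_DIR`/`DATASETS`/`TARGET_DIR` values configured in the file; shard counts and file names are illustrative and depend on input sizes.

```python
# Expected layout after the pipeline runs:
#
#   datasets/dclm/dclm.chunk.00.jsonl                          input chunk (BASE_DIR)
#   entropy_preprocess/dclm/dclm.chunk.00.jsonl.shard_00       <= 2500 MB each, split on
#   entropy_preprocess/dclm/dclm.chunk.00.jsonl.shard_01       line boundaries (`split -C`)
#   entropy_preprocess/dclm/dclm.chunk.00.jsonl.shard.COMPLETE completion marker
#
# The module builds its own local scheduler, so a plain invocation is enough:
#
#   python -m bytelatent.preprocess.data_pipeline
```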
+import subprocess +from pathlib import Path + +import submitit +import typer + + +class PreprocessEntropiesJob(submitit.helpers.Checkpointable): + def __init__(self) -> None: + pass + + def __call__(self, shard_file: str, output_filename: str): + subprocess.run( + [ + "python", + "-u", + "-m", + "bytelatent.preprocess.preprocess_entropies", + str(shard_file), + str(output_filename), + ], + check=True, + ) + return True + + +def chunk(items, size): + for i in range(0, len(items), size): + yield items[i : i + size] + + +def main( + job_folder: str, + input_dir: str, + output_dir: str, + qos: str = "explore", + slurm_batch_size: int = 1000, + check_only: bool = False, + wait: bool = False, +): + input_dir = Path(input_dir) + output_dir = Path(output_dir) + shard_files = [ + p for p in input_dir.glob("*.jsonl.shard*") if "COMPLETE" not in p.name + ] + if check_only: + exist = [] + missing = [] + for shard_file in shard_files: + shard_file = Path(shard_file) + complete_file = output_dir / f"{shard_file.name}.arrow.complete" + if complete_file.exists(): + exist.append(complete_file) + else: + missing.append(complete_file) + print("Checked for output files for input_dir=", input_dir) + print("Exist:", len(exist)) + print("Missing:", len(missing)) + print(missing) + return + print("Running parallel job over N files=", len(shard_files)) + print("Input Directory:", input_dir) + print("Output Directory:", output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + + executor = submitit.SlurmExecutor(job_folder) + executor.update_parameters( + # 12 hours in minutes + time=60 * 12, + qos=qos, + exclusive="user", + cpus_per_task=4, + num_gpus=1, + mem_per_gpu="80G", + array_parallelism=slurm_batch_size, + ) + + jobs = [] + n_batches = 0 + n_skipped = 0 + n_launched = 0 + for file_batch in chunk(shard_files, slurm_batch_size): + with executor.batch(): + for shard_file in file_batch: + output_filename = Path(output_dir) / f"{shard_file.name}.arrow" + complete_output_filename = ( + Path(output_dir) / f"{shard_file.name}.arrow.complete" + ) + if complete_output_filename.exists(): + n_skipped += 1 + else: + job = executor.submit( + PreprocessEntropiesJob(), str(shard_file), str(output_filename) + ) + n_launched += 1 + jobs.append(job) + n_batches += 1 + print("launched array jobs n=", n_launched) + print("skipped (completed) array jobs n=", n_skipped) + print("number of slurm batches=", n_batches) + if wait: + output = [job.result() for job in jobs] + assert all(output) + + +if __name__ == "__main__": + typer.run(main) diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py new file mode 100644 index 0000000..20d1e0c --- /dev/null +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
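Typer exposes `main` in `parallel_entropies.py` above as a CLI. Two hedged invocation examples follow; the directory names are placeholders, and note that the underlying `preprocess_entropies` entry point is still marked TODO in this commit.

```python
# Submit one Slurm array task per shard that has no *.arrow.complete marker yet:
#
#   python -m bytelatent.preprocess.parallel_entropies \
#       slurm_jobs/ entropy_preprocess/dclm entropy_arrow/dclm --qos explore
#
# Only count existing vs. missing outputs, without submitting anything:
#
#   python -m bytelatent.preprocess.parallel_entropies \
#       slurm_jobs/ entropy_preprocess/dclm entropy_arrow/dclm --check-only
```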
+import time +from pathlib import Path + +import numpy as np +import pyarrow as pa +import torch +import typer +from rich.progress import Progress, TextColumn + +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator + + +def main( + input_file: str, + output_file: str, + patching_device: str = "cuda", + log_step: int = 10_000, + entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir", + dry_run: bool = False, +): + # TODO: Modify this to work with the new code + raise NotImplementedError() + iterator = ArrowFileIterator( + file_path=input_file, + worker_id=0, + num_workers=1, + ) + tokenization_mode = "bytes" + print(f"Preprocessing entropies, input: {input_file}, output: {output_file}") + print("Loading entropy model", entropy_model_checkpoint_dir) + if dry_run: + return + entropy_model = load_entropy_model( + entropy_model_checkpoint_dir, device=patching_device + ) + entropy_model, _ = to_device(entropy_model, patching_device) + print("Creating patcher") + patching_batch_size = 32 + print("Creating tokenizer") + tokenizer = Tokenizer( + model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", + tokenization_mode=tokenization_mode, + # BYTE_UNITS + vocab_size_unit_1=256, + bos=True, + eos=True, + bpe_delim=False, + # This isn't used, just stores a reference for other calls we don't use + patcher=None, + ) + step = 0 + print("starting") + start_time = time.time() + patch_time = 0 + entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False) + sample_id_field = pa.field("sample_id", pa.string(), nullable=False) + text_field = pa.field("text", pa.string(), nullable=False) + schema = pa.schema([sample_id_field, text_field, entropy_field]) + arrow_batch_size = 1_000 + + try: + with pa.OSFile(output_file, "wb") as sink: + with pa.ipc.new_file(sink, schema) as writer: + id_buffer = [] + entropies_buffer = [] + text_buffer = [] + with Progress( + *Progress.get_default_columns(), + TextColumn("Completed: {task.completed}"), + ) as progress: + task = progress.add_task( + "[green]Calculating entropies...", total=None + ) + for doc in iterator: + sample_id = get_id_from_doc(doc) + + if "text" in doc: + text = doc["text"] + elif "content" in doc: + text = doc["content"] + else: + raise ValueError( + f"Could not find a text key from: {doc.keys()}" + ) + tokens = torch.tensor(tokenizer.encode(text)) + patch_start = time.time() + scores = calculate_entropies( + tokens, + entropy_model, + patching_batch_size, + patching_device, + ) + entropies_buffer.append( + np.array(scores.tolist(), dtype=np.float16) + ) + id_buffer.append(sample_id) + text_buffer.append(text) + if len(entropies_buffer) == arrow_batch_size: + batch = pa.record_batch( + { + "entropies": entropies_buffer, + "sample_id": id_buffer, + "text": text_buffer, + }, + schema, + ) + writer.write(batch) + entropies_buffer = [] + id_buffer = [] + text_buffer = [] + patch_time += time.time() - patch_start + step += 1 + if step % log_step == 0: + print("Completed steps:", step) + progress.update(task, advance=1) + if len(entropies_buffer) > 0: + # Write last things + batch = pa.record_batch( + { + "entropies": entropies_buffer, + "sample_id": id_buffer, + "text": text_buffer, + }, + schema, + ) + writer.write(batch) + entropies_buffer = [] + id_buffer = [] + text_buffer = [] + Path(f"{output_file}.complete").touch() + except: + Path(output_file).unlink(missing_ok=True) + raise + elapsed = time.time() - start_time + print("steps", step) + print("done in:", elapsed) + + +if __name__ == "__main__": + 
typer.run(main) diff --git a/bytelatent/probe.py b/bytelatent/probe.py new file mode 100644 index 0000000..7bdc03c --- /dev/null +++ b/bytelatent/probe.py @@ -0,0 +1,694 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This file from the xFormers repo is just a example of how to implement +# probing of the activations of a model, without changing anything. +# By default, the linear inputs/outputs/gradients are logged, as well as +# the attention logits+entropy. It is possible to log an additional tensor, eg: +# x = log_stats(x, "name") +# +# Known limitations: +# * Only a subset of the attention biases is supported +# * Torch-compile is disabled automatically when this is enabled +# * Only tested with bf16/f16/f32 datatypes + +import contextlib +import functools +import json +import math +import os +import uuid +from collections import defaultdict +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + checkpoint_wrapper, +) +from torch.fx.operator_schemas import normalize_function +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map +from torch.utils.module_tracker import ModuleTracker +from xformers.ops import fmha + + +@torch.library.custom_op("torchprobe::log", mutates_args=(), device_types=None) +def _log(x: torch.Tensor, name: str, uid: str) -> None: + pass + + +@_log.register_fake +def _log_fake(x: torch.Tensor, name: str, uid: str) -> None: + pass + + +class _LogStats(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, name: str): + uid = str(uuid.uuid4()) + torch.ops.torchprobe.log(x, name, uid) + ctx.name = name + ctx.uid = uid + return x + + @staticmethod + def backward(ctx, grad: torch.Tensor): + torch.ops.torchprobe.log(grad, f"{ctx.name}.g", ctx.uid) + return grad, None + + +_PROBING_ENABLED = False + + +def log_stats(x: torch.Tensor, name: str) -> torch.Tensor: + if not _PROBING_ENABLED: + return x + return _LogStats.apply(x, name) + + +QUANTILES = [ + 0.0000001, + 0.000001, + 0.00001, + 0.0001, + 0.001, + 0.01, + 0.05, + 0.1, + 0.3, + 0.5, + 0.7, + 0.9, + 0.95, + 0.99, + 0.999, + 0.9999, + 0.99999, + 0.999999, + 0.9999999, +] + + +@functools.cache +def _get_quantiles(device: torch.device, dtype) -> torch.Tensor: + return torch.tensor(QUANTILES, device=device, dtype=dtype) + + +def _get_stats(x_: torch.Tensor, remove_inf=False) -> Dict[str, Any]: + if x_.dtype not in [torch.float, torch.double, torch.float16, torch.bfloat16]: + return {} + x = x_.flatten() + if remove_inf: + x = x[x.abs() < float("inf")] + if x.dtype is not torch.double: + x = x.float() + xabs = x.abs() + quantiles = _get_quantiles(x.device, x.dtype) + mean = x.mean() + std = x.std() + return { + "shape": tuple(x_.shape), + "mean": mean, + "std": std, + "skew": (((x - mean) / std) ** 3).double().mean(), + "kurtosis": (((x - mean) / std) ** 4).double().mean(), + "abs.mean": xabs.mean(), + "max": x.max(), + "min": x.min(), + # Note: `quantile` takes at most 2**24 elements, see + # https://github.com/pytorch/pytorch/issues/64947 + "quantiles": torch.quantile(x[: 2**24], quantiles), + } + + +def _mask_attn_causal_inplace(logits: torch.Tensor, q_idx, q_len, kv_len) -> None: + assert logits.ndim == 4 + logits[:, :, :, q_idx + kv_len - q_len + 1 :] = -math.inf + + 
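A small sketch of what the probing helpers above record, assuming `bytelatent.probe` imports cleanly in your environment; outside an active `AutoProbeD` context `log_stats` is a passthrough, and `_get_stats` is called directly here only to show the recorded fields.

```python
import torch

from bytelatent.probe import _get_stats, log_stats

x = torch.randn(4, 8)
x = log_stats(x, "demo")   # no-op unless probing is enabled
print(sorted(_get_stats(x)))
# ['abs.mean', 'kurtosis', 'max', 'mean', 'min', 'quantiles', 'shape', 'skew', 'std']
```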
+def _mask_attn_logits( + logits: torch.Tensor, + q_idx: List[int], + *, + causal: bool, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert logits.dtype is torch.float32 + # Handle BlockDiagonalMask + if cu_seqlens_q is not None: + assert cu_seqlens_k is not None + # Expect BHMqMkv + assert logits.ndim == 4, logits.shape + qs = cu_seqlens_q.tolist() + ks = cu_seqlens_k.tolist() + q_batchid = [] + k_batchid = [-2] * logits.shape[-1] + q_idx_i = 0 + for bid, (q0, q1, k0, k1) in enumerate(zip(qs, qs[1:], ks, ks[1:])): + for k in range(k0, k1): + k_batchid[k] = bid + while q_idx_i < len(q_idx) and q_idx[q_idx_i] < q1: + q_batchid.append(bid) + if causal: + _mask_attn_causal_inplace( + logits[:, :, q_idx_i : q_idx_i + 1, k0:k1], + q_idx[q_idx_i] - q0, + q1 - q0, + k1 - k0, + ) + q_idx_i += 1 + mask_out = ( + torch.tensor(q_batchid, device=logits.device)[None, None, :, None] + != torch.tensor(k_batchid, device=logits.device)[None, None, None, :] + ) + logits[mask_out.expand_as(logits)] = -math.inf + assert q_idx_i == len(q_idx) + elif causal: + for q_idx_i in range(len(q_idx)): + _mask_attn_causal_inplace( + logits[:, :, q_idx_i : q_idx_i + 1, :], + q_idx[q_idx_i], + logits.shape[2], + logits.shape[3], + ) + return logits + + +def _attn_queries_subset(num_queries: int) -> List[int]: + return list(range(0, num_queries, max(1, num_queries // 128))) + + +@torch.no_grad() +def _compute_attn_stats_sdpa( + probe, + path: str, + # supports arguments both cudnn + flash backends + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask=None, + attn_bias=None, + dropout_p=0.0, + is_causal=False, + scale=None, + compute_log_sumexp=True, + return_debug_mask=False, + **kwargs, +): + if scale is None: + scale = 1 / (query.shape[-1] ** 0.5) + # Filter-out not supported cases + if attn_mask is not None or attn_bias is not None or dropout_p != 0.0 or kwargs: + probe.store[f"{path}::attn"] = { + "query.shape": tuple(query.shape), + "key.shape": tuple(key.shape), + "value.shape": tuple(value.shape), + "attn_mask": attn_mask.shape if attn_mask is not None else None, + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "unk_kwargs": list(kwargs.keys()), + } + return + # Take a subset of the queries and compute the logits + query_s = _attn_queries_subset(query.shape[-2]) + logits = query[:, :, query_s] @ key.transpose(-1, -2) * scale + logits = _mask_attn_logits(logits.float(), query_s, causal=is_causal) + p = logits.float().softmax(-1) + masked_logsoft = logits.log_softmax(-1).where( + (logits > -math.inf), torch.zeros_like(logits) + ) + entropy = -(p * masked_logsoft).sum(-1) + probe.log_tensor(f"{path}::attn_entropy", entropy) + probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True) + + +@torch.no_grad() +def _compute_attn_stats_flash( + probe, + path: str, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + p: float, + softmax_scale: float, + is_causal: bool, + window_left: int, + window_right: int, + return_softmax: bool, + block_tables: Optional[torch.Tensor], + unpadded_lse: bool = False, +) -> None: + # Filter-out not supported cases + if ( + seqused_k is not None + or p != 0.0 + or window_left >= 0 + or window_right >= 0 + or block_tables is not None + ): + probe.store[f"{path}::attn"] = { + "query.shape": 
tuple(query.shape), + "key.shape": tuple(key.shape), + "value.shape": tuple(value.shape), + "op": "flash", + } + return + + if cu_seqlens_q is not None: + assert query.ndim == 3, query.shape + query, key, value = query[None], key[None], value[None] + assert query.ndim == 4, query.shape + + # Take a subset of the queries and compute the logits + query_s = _attn_queries_subset(query.shape[1]) + logits = ( + query[:, query_s].transpose(1, 2) + @ key.transpose(1, 2).transpose(-1, -2) + * softmax_scale + ) + logits = _mask_attn_logits( + logits.float(), + query_s, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + causal=is_causal, + ) + p = logits.float().softmax(-1) + masked_logsoft = logits.log_softmax(-1).where( + (logits > -math.inf), torch.zeros_like(logits) + ) + entropy = -(p * masked_logsoft).sum(-1) + probe.log_tensor(f"{path}::attn_entropy", entropy) + probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True) + + +def _tensors_to_python(x): + if not isinstance(x, torch.Tensor): + return x + return x.tolist() + + +# class syntax +class LinearBwType(Enum): + DW = 1 + DX = 2 + UNKNOWN = 3 + + +class AutoProbeD(TorchDispatchMode): + def __init__(self, module: nn.Module, write_file: Optional[str] = None) -> None: + self.write_file = Path(write_file) if write_file is not None else None + self.write_tensors_tmpdir: Optional[Path] = None + self.compile_disabler = TorchCompileDisabler(module) + self.mod_tracker = ModuleTracker() + self.count_per_path: Dict[str, int] = defaultdict(int) + self.store: Dict[str, Dict[str, Any]] = {} + self.linear_data: Dict[str, Tuple[Any, Any, Any, Any, Any]] = {} + self.uid_to_path: Dict[str, str] = {} + self.metadata: Any = None + self.enabled = False + self.verbose = bool(int(os.environ.get("PROBE_VERBOSE", "0"))) + + def __enter__(self): + global _PROBING_ENABLED + assert not self.enabled, "Entered probe twice" + self.compile_disabler.__enter__() + self.mod_tracker.__enter__() + super().__enter__() + self.enabled = True + _PROBING_ENABLED = True + # self._setup_tensors_logging() + return self + + def __exit__(self, *args) -> None: + global _PROBING_ENABLED + assert self.enabled, "Exiting probe without entering it" + super().__exit__(*args) + self.mod_tracker.__exit__(*args) + self.compile_disabler.__exit__(*args) + self._flush_and_clear() + _PROBING_ENABLED = False + self.enabled = False + + def _setup_tensors_logging(self): + if self.write_file is not None: + self.write_file.parent.mkdir(exist_ok=True) + self.write_tensors_tmpdir = ( + self.write_file.parent + / f"{self.write_file.name}-tmp-{str(uuid.uuid4())[:8]}" + ) + self.write_tensors_tmpdir.mkdir(exist_ok=True) + + def _flush_and_clear(self) -> None: + if self.write_file is not None: + dump_data = tree_map(_tensors_to_python, self.store) + with self.write_file.open("a") as fd: + json.dump( + { + "data": dump_data, + "meta": self.metadata, + "version": 2, + "quantiles": QUANTILES, + }, + fd, + ) + fd.write("\n") + if self.write_tensors_tmpdir is not None: + assert self.write_file is not None + dump_dir = self.write_tensors_tmpdir.parent / f"{self.write_file.name}-dump" + dump_dir.mkdir(exist_ok=True) + dir_name = "" + if "it" in self.metadata: + dir_name = f"it{int(self.metadata['it']):010}" + if dir_name == "" or (dump_dir / dir_name).exists(): + num_files = len(list(dump_dir.glob(f"{dir_name}v*"))) + dir_name = f"{dir_name}v{num_files}" + dump_dir = dump_dir / dir_name + assert not dump_dir.exists() + self.write_tensors_tmpdir.rename(dump_dir) + self.write_tensors_tmpdir = None + 
self.store.clear() + self.count_per_path.clear() + self.uid_to_path.clear() + + def _find_bw_path_and_type( + self, path: str, out: torch.Tensor, args + ) -> Tuple[str, LinearBwType]: + """ + We are in the BW pass, and process a GEMM. + Let's figure out: + (1) The path for the FW pass (might differ in case of ModuleTracker bug) + (2) The type of BW pass (eg `dw` or `dx`) + """ + + def _is_path_correct_dw(path: str) -> bool: + # dW.t = dY.t @ X + in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path] + return out.shape == (w_shape[1], w_shape[0]) and torch.allclose( + input_sm, args[1][:4, :4] + ) + + def _is_path_correct_dx(path: str) -> bool: + # dX = dY @ W.t + in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path] + return out.shape == in_shape and torch.allclose(weight_sm, args[1][:4, :4]) + + if path in self.linear_data: + if _is_path_correct_dw(path): + return path, LinearBwType.DW + if _is_path_correct_dx(path): + return path, LinearBwType.DX + for candidate_path in self.mod_tracker.parents: + if candidate_path not in self.linear_data: + continue + if _is_path_correct_dw(candidate_path): + return candidate_path, LinearBwType.DW + if _is_path_correct_dx(candidate_path): + return candidate_path, LinearBwType.DX + return path, LinearBwType.UNKNOWN + + def log_tensor(self, name: str, x: torch.Tensor, **kwargs) -> None: + self.store[name] = _get_stats(x, **kwargs) + if self.write_tensors_tmpdir is not None: + name_safe = name.replace("::", "__").replace("/", "") + torch.save(x, self.write_tensors_tmpdir / f"{name_safe}.pkl") + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + path = None + # Find longest path + for p in self.mod_tracker.parents: + if p == "Global": + continue + if path is None or len(p) > len(path): + path = p + if path is None: + path = "Global" + path = path.replace("._checkpoint_wrapped_module", "") + out = func(*args, **kwargs) + + # Handle linear layers + if func._overloadpacket in [torch.ops.aten.addmm, torch.ops.aten.mm]: + weight: torch.Tensor + input: torch.Tensor + if not self.mod_tracker.is_bw: + # (technically, weight is transposed) + if func._overloadpacket == torch.ops.aten.addmm: + _bias, input, weight = args[:3] + else: + assert func._overloadpacket == torch.ops.aten.mm + input, weight = args[:2] + self.log_tensor(f"{path}::in", input) + self.log_tensor(f"{path}::w", weight) + self.log_tensor(f"{path}::out", out) + self.linear_data[path] = ( + input.shape, + weight.shape, + out.shape, + input[:4, :4].clone(), + weight[:4, :4].T.clone(), + ) + elif func._overloadpacket == torch.ops.aten.mm: + # XXX: Try to find the actual path for the linear layer + # This is messed with with Francisco's FSDP sometimes + new_path, bwtype = self._find_bw_path_and_type(path, out, args) + if new_path != path: + if self.verbose: + print(f"E: Fixing path `{path}` -> `{new_path}") + path = new_path + + if bwtype == LinearBwType.DW: + # dW.t = dY.t @ X + self.log_tensor(f"{path}::w.g", out) + elif bwtype == LinearBwType.DX: + # dX = dY @ W.t + self.log_tensor(f"{path}::in.g", out) + self.log_tensor(f"{path}::out.g", args[0]) + elif func._overloadpacket in [ + torch.ops.aten._scaled_dot_product_flash_attention, + torch.ops.aten._scaled_dot_product_cudnn_attention, + ]: + _, kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + _compute_attn_stats_sdpa(self, path, **kwargs) + elif func._overloadpacket == fmha.flash.FwOp.OPERATOR: + _, 
kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + _compute_attn_stats_flash(self, path, **kwargs) + elif func._overloadpacket == torch.ops.torchprobe.log: + uid = args[2] + path = self.uid_to_path.setdefault(uid, path) + self.log_tensor(f"{path}::{args[1]}", args[0]) + if self.verbose: + print(f"{'[BW]' if self.mod_tracker.is_bw else '[FW]'} `{path}`: {func}") + return out + + +def _find_all_submodules_compiled(out: List[nn.Module], module: nn.Module) -> None: + if module._compiled_call_impl is not None: + out.append(module) + for c in module.children(): + _find_all_submodules_compiled(out, module=c) + + +class TorchCompileDisabler: + def __init__(self, module: nn.Module) -> None: + self.module = module + self.submodules_compiled: List[nn.Module] = [] + self.compiled_call_impl: List[Any] = [] + self.disable_compile = torch.compiler.disable() + torch._dynamo.config.raise_on_ctx_manager_usage = False # type: ignore + + def __enter__(self) -> None: + # Remove all `_compiled_call_impl` attributes to effectively + # "undo" compilation + self.submodules_compiled.clear() + _find_all_submodules_compiled(self.submodules_compiled, self.module) + self.compiled_call_impl = [ + m._compiled_call_impl for m in self.submodules_compiled + ] + for m in self.submodules_compiled: + m._compiled_call_impl = None + self.disable_compile.__enter__() # type: ignore + + def __exit__(self, *args) -> None: + self.disable_compile.__exit__(*args) # type: ignore + for m, c_impl in zip(self.submodules_compiled, self.compiled_call_impl): + m._compiled_call_impl = c_impl + self.compiled_call_impl = [] + + +Probe = AutoProbeD + +# EXAMPLE USAGE +d = 512 +seqlen = 4 +bs = 2 + + +class Attention1(nn.Module): + def forward(self, x): + attn_bias = fmha.attn_bias.LowerTriangularFromBottomRightMask() + return fmha.memory_efficient_attention(x, x, x, attn_bias=attn_bias).reshape( + [x.shape[0], seqlen, -1] + ) + + +class Attention2(nn.Module): + def forward(self, x): + attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens( + [seqlen] * bs + ).make_causal() + xr = x.reshape([1, 2 * seqlen, x.shape[2], x.shape[3]]) + return fmha.memory_efficient_attention(xr, xr, xr, attn_bias=attn_bias).reshape( + [x.shape[0], seqlen, -1] + ) + + +class AttentionSDPA(nn.Module): + def __init__(self): + super().__init__() + self.wo = nn.Linear(d, d) + + def forward(self, x): + x = x.transpose(1, 2) + return self.wo( + F.scaled_dot_product_attention(x, x, x) + .transpose(1, 2) + .reshape([x.shape[0], seqlen, -1]) + ) + + +class AttentionSDPAFlash(AttentionSDPA): + def forward(self, x): + x = x.transpose(1, 2) + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + return self.wo( + F.scaled_dot_product_attention(x, x, x) + .transpose(1, 2) + .reshape([x.shape[0], seqlen, -1]) + ) + + +class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.head = nn.Linear(d, 16) + self.trunk = nn.Sequential( + nn.Linear(d, d), + nn.Linear(d, d), + ) + self.q_proj = nn.Linear(d, d, bias=False) + self.trunk.compile() + self.attn1 = Attention1() + self.attn2 = Attention2() + self.attnSDPA = AttentionSDPA() + self.attnSDPAflash = AttentionSDPAFlash() + + def forward(self, x): + B, nHeads, D = x.shape[0], d // 64, 64 + x = self.q_proj(x).reshape([B, seqlen, nHeads, D]) + x = self.attn1(x) + self.attn2(x) + self.attnSDPA(x) + self.attnSDPAflash(x) + x = log_stats(x, "attns_out") + return self.head(self.trunk(x)) + + +def test_masking() -> None: + q_seqlen = [1, 1, 14, 12] + kv_seqlen = [2, 2, 14, 
18] + attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen, kv_seqlen + ).make_causal_from_bottomright() + logits = torch.randn( + [1, 1, sum(q_seqlen), sum(kv_seqlen)], dtype=torch.float32, device="cuda" + ) + bias = attn_bias.materialize(logits.shape, dtype=logits.dtype, device=logits.device) + logits_masked = logits.clone() + _mask_attn_logits( + logits_masked, + list(range(logits.shape[2])), + causal=True, + cu_seqlens_q=attn_bias.q_seqinfo.seqstart, + cu_seqlens_k=attn_bias.k_seqinfo.seqstart, + ) + assert (logits + bias == logits_masked).all().item() + + +def test_toy_model() -> None: + # Test masking + kw = dict(device="cuda", dtype=torch.float16) + x = torch.randn([bs, seqlen, d], **kw) + m = Model() + m.head = checkpoint_wrapper( + m.head, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False + ) + m.to(**kw) + m.compile() + optim = torch.optim.SGD(m.parameters(), lr=0.0) + probe = AutoProbeD(m, "./probe.json") + + for i in range(4): + with contextlib.ExitStack() as stack: + print(f"########### STEP {i}") + if i % 4 == 1: + stack.enter_context(probe) + probe.metadata = {"it": i} + y = m(x) + g = torch.randn_like(y) + y.backward(g) + if i % 4 == 1: + assert probe.enabled + # Make sure we registered all linears + print(list(probe.store.keys())) + for key in [ + "Model::attns_out", + "Model::attns_out.g", + "Model.attn1::attn_logits", + "Model.attn2::attn_logits", + "Model.attnSDPA::attn_logits", + "Model.attnSDPAflash::attn_logits", + "Model.head::w", + "Model.head::w.g", + "Model.head::in", + "Model.head::in.g", + "Model.head::out", + "Model.head::out.g", + "Model.trunk.0::in", + "Model.trunk.1::in", + ]: + assert key in probe.store, f"Missing key: '{key}'" + # .. and that the values are correct + for key, tensor in [ + ("Model.head::w", m.head.weight), + ("Model.head::w.g", m.head.weight.grad), + ("Model.q_proj::in", x), + ("Model.q_proj::w.g", m.q_proj.weight.grad), + ("Model.head::out", y), + ("Model.head::out.g", g), + ]: + assert key in probe.store, f"Missing key: '{key}'" + assert torch.allclose( + probe.store[key]["abs.mean"], tensor.float().abs().mean() + ), f"'{key}' mismatches" + # Check we don't have `nans` + for key, value in probe.store.items(): + if "abs.mean" in value: + assert math.isfinite( + value["abs.mean"].item() + ), f"Inf/Nan for {key}" + optim.step() + optim.zero_grad() diff --git a/bytelatent/profiling.py b/bytelatent/profiling.py new file mode 100644 index 0000000..da3c90d --- /dev/null +++ b/bytelatent/profiling.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +import contextlib +import logging +import os +from pathlib import Path + +import torch.distributed +import wandb +import xformers.profiler +from pydantic import BaseModel +from torch.profiler.profiler import profile +from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler + +from bytelatent.distributed import get_is_master + + +class ProfilerArgs(BaseModel): + run: bool = False + trace_folder: str = "profiling" + mem_warmup: int = 100 + mem_steps: int = 2 + profile_warmup: int = 102 + profile_steps: int = 2 + + +logger = logging.getLogger() + + +def perfetto_to_html(json_file, html_file): + import gzip + import string + + import viztracer + + root = os.path.dirname(viztracer.__file__) + sub = {} + json_file = gzip.open(json_file) if ".gz" in str(json_file) else open(json_file) + with open( + os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8" + ) as f: + tmpl = f.read() + with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f: + sub["trace_viewer_full"] = f.read() + with json_file as j: + content = j.read() + if isinstance(content, bytes): + content = content.decode("utf-8") + sub["json_data"] = content.replace("", "<\\/script>") # type: ignore + with open(html_file, "w+", encoding="utf-8") as output_file: + output_file.write(string.Template(tmpl).substitute(sub)) + + +class PyTorchProfilerWandb(PyTorchProfiler): + def __init__(self, main_profiler) -> None: + self.main_profiler = main_profiler + self.num_steps = 0 + self.pytorch_profiler = torch.profiler.profile( + on_trace_ready=self._on_trace, + profile_memory=True, + record_shapes=True, + # With stack gives huge profile traces + # and bugs out because of some non ascii + # character somewhere in pytorch + with_stack=False, + with_flops=True, + activities=self.ACTIVITIES, + ) + + def _analyze_trace(self, prof: profile): + logger.info("Begin analyze trace") + super()._analyze_trace(prof) + logger.info("End analyze trace") + + def _on_trace(self, prof: torch.profiler.profiler.profile) -> None: + super()._on_trace(prof) + if get_is_master() and wandb.run is not None: + filename = list( + Path(self.main_profiler.output_dir).glob( + "profile_CPU_CUDA*/*.pt.trace.json*" + ) + )[0] + html_path = str(filename).replace(".json", ".html") + perfetto_to_html(filename, html_path) + wandb.log({"profile_trace": wandb.Html(html_path)}) + + +class MemSnapshotsProfilerWandb(MemSnapshotsProfiler): + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + if get_is_master() and wandb.run is not None: + filename = list( + Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html") + )[0] + wandb.log({"memory_trace": wandb.Html(open(filename), inject=False)}) + + +@contextlib.contextmanager +def maybe_run_profiler(dump_dir, module, config: ProfilerArgs): + # get user defined profiler settings + + if config.run: + trace_dir = os.path.join(dump_dir, config.trace_folder) + + logger.info(f"Profiling active. 
Traces will be saved at {trace_dir}") + + if get_is_master() and not os.path.exists(trace_dir): + os.makedirs(trace_dir) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + with xformers.profiler.profile( + output_dir=trace_dir, + module=module, + schedule=[ + ( + MemSnapshotsProfilerWandb, + config.mem_warmup, + config.mem_warmup + config.mem_steps, + ), + ( + PyTorchProfilerWandb, + config.profile_warmup, + config.profile_warmup + config.profile_steps, + ), + ], + ) as profiler: + yield profiler + + else: + torch_profiler = contextlib.nullcontext() + yield None diff --git a/bytelatent/stool.py b/bytelatent/stool.py new file mode 100644 index 0000000..965f4cb --- /dev/null +++ b/bytelatent/stool.py @@ -0,0 +1,237 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import json +import os +import shutil +import subprocess +from dataclasses import dataclass +from typing import Any, Dict + +from omegaconf import OmegaConf + + +@dataclass +class StoolArgs: + config: Any = None + launcher: str = "sbatch" # Can be sbatch or bash if already in salloc + script: str = "apps.main.train" # The script to run. + copy_code: bool = True # Wether to copy code to dump dir + dirs_exists_ok: bool = ( + False # Wether to copy new code and config and run regardless that dir exists + ) + override: bool = False # Wether to delete dump dir and restart + nodes: int = -1 # The number of nodes to run the job on. + ngpu: int = 8 # The number of GPUs required per node. + ncpu: int = 16 # The number of CPUs allocated per GPU. + mem: str = "" # The amount of memory to allocate. + anaconda: str = "default" # The path to the anaconda environment. + constraint: str = "" # The constraint on the nodes. + exclude: str = "" # The nodes to exclude. + time: int = -1 # The time limit of the job (in minutes). 
+ account: str = "" + qos: str = "" + partition: str = "learn" + stdout: bool = False + + +SBATCH_COMMAND = """#!/bin/bash + +{exclude} +{qos} +{account} +{constraint} +#SBATCH --job-name={name} +#SBATCH --nodes={nodes} +#SBATCH --gres=gpu:{ngpus} +#SBATCH --cpus-per-gpu={ncpu} +#SBATCH --time={time} +#SBATCH --partition={partition} +#SBATCH --mem={mem} + +#SBATCH --output={dump_dir}/logs/%j/%j.stdout +#SBATCH --error={dump_dir}/logs/%j/%j.stderr + +#SBATCH --open-mode=append +#SBATCH --signal=USR2@120 +#SBATCH --distribution=block + +# Mimic the effect of "conda init", which doesn't work for scripts +eval "$({conda_exe} shell.bash hook)" +source activate {conda_env_path} + +{go_to_code_dir} + +export OMP_NUM_THREADS=1 +export LAUNCH_WITH="SBATCH" +export DUMP_DIR={dump_dir} +srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml +""" + + +def copy_dir(input_dir: str, output_dir: str) -> None: + print(f"Copying : {input_dir}\n" f"to : {output_dir} ...") + assert os.path.isdir(input_dir), f"{input_dir} is not a directory" + assert os.path.isdir(output_dir), f"{output_dir} is not a directory" + rsync_cmd = ( + f"rsync -arm --copy-links " + f"--include '**/' " + f"--include '*.py' " + f"--exclude='*' " + f"{input_dir}/ {output_dir}" + ) + print(f"Copying command: {rsync_cmd}") + subprocess.call([rsync_cmd], shell=True) + print("Copy done.") + + +def retrieve_max_time_per_partition() -> Dict[str, int]: + # retrieve partition max times (a bit slow) + + sinfo = json.loads(subprocess.check_output("sinfo --json", shell=True))["sinfo"] + max_times: Dict[str, int] = {} + + for info in sinfo: + if info["partition"]["maximums"]["time"]["infinite"]: + max_times[info["partition"]["name"]] = 14 * 24 * 60 # 14 days + else: + max_times[info["partition"]["name"]] = info["partition"]["maximums"][ + "time" + ][ + "number" + ] # in minutes + + return max_times + + +def validate_args(args) -> None: + # Set maximum time limit if not specified + if args.time == -1: + max_times = retrieve_max_time_per_partition() + args.time = max_times.get( + args.partition, 3 * 24 * 60 + ) # Default to 3 days if not found + print( + f"No time limit specified, using max time for partitions: {args.time} minutes" + ) + + if args.constraint: + args.constraint = f"#SBATCH --constraint={args.constraint}" + + if args.account: + args.account = f"#SBATCH --account={args.account}" + + if args.qos: + args.qos = f"#SBATCH --qos={args.qos}" + + if getattr(args, "exclude", ""): + args.exclude = f"#SBATCH --exclude={args.exclude}" + + if hasattr(args, "anaconda") and args.anaconda: + if args.anaconda == "default": + args.anaconda = ( + subprocess.check_output("which python", shell=True) + .decode("ascii") + .strip() + ) + else: + args.anaconda = f"{args.anaconda}/bin/python" + assert os.path.isfile(args.anaconda) + + args.mem = args.mem or "0" + + assert args.partition + assert args.ngpu > 0 + assert args.ncpu > 0 + assert args.nodes > 0 + assert args.time > 0 + assert args.partition + + +def launch_job(args: StoolArgs): + # Set up args default and validate them depending on the cluster or partition requested + validate_args(args) + dump_dir = args.config["dump_dir"] + job_name = args.config["name"] + print("Creating directories...") + os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override) + if args.override: + confirm = input( + f"Are you sure you want to delete the directory '{dump_dir}'? This action cannot be undone. 
(yes/no): " + ) + if confirm.lower() == "yes": + shutil.rmtree(dump_dir) + print(f"Directory '{dump_dir}' has been deleted.") + else: + print("Operation cancelled.") + return + if args.copy_code: + os.makedirs(f"{dump_dir}/code", exist_ok=args.dirs_exists_ok) + print("Copying code ...") + copy_dir(os.getcwd(), f"{dump_dir}/code") + + print("Saving config file ...") + with open(f"{dump_dir}/base_config.yaml", "w") as cfg: + cfg.write(OmegaConf.to_yaml(args.config)) + + conda_exe = os.environ.get("CONDA_EXE", "conda") + conda_env_path = os.path.dirname(os.path.dirname(args.anaconda)) + log_output = ( + "-o $DUMP_DIR/logs/%j/%j_%t.out -e $DUMP_DIR/logs/%j/%j_%t.err" + if not args.stdout + else "" + ) + sbatch = SBATCH_COMMAND.format( + name=job_name, + script=args.script, + dump_dir=dump_dir, + nodes=args.nodes, + tasks=args.nodes * args.ngpu, + nodes_per_run=args.nodes, + ngpus=args.ngpu, + ncpu=args.ncpu, + mem=args.mem, + qos=args.qos, + account=args.account, + constraint=args.constraint, + exclude=args.exclude, + time=args.time, + partition=args.partition, + conda_exe=conda_exe, + conda_env_path=conda_env_path, + log_output=log_output, + go_to_code_dir=f"cd {dump_dir}/code/" if args.copy_code else "", + ) + + print("Writing sbatch command ...") + with open(f"{dump_dir}/submit.slurm", "w") as f: + f.write(sbatch) + + print("Submitting job ...") + os.system(f"{args.launcher} {dump_dir}/submit.slurm") + + print("Done.") + + +if __name__ == "__main__": + """ + The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments + This accepts arguments as a dot list + So if the dataclass looks like + + @dataclass + class DummyArgs: + name: str + mode: LMTransformerArgs + + @dataclass + class LMTransformerArgs: + dim: int + + Then you can pass model.dim=32 to change values in LMTransformerArgs + or just name=tictac for top level attributes. + """ + raise NotImplementedError("Update this to blt code") + args = OmegaConf.from_cli() + args.config = OmegaConf.load(args.config) + args = dataclass_from_dict(StoolArgs, args) + launch_job(args) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py new file mode 100644 index 0000000..73ad9f7 --- /dev/null +++ b/bytelatent/test_blt.py @@ -0,0 +1,471 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
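The `__main__` docstring above describes OmegaConf's dot-list overrides; since the stool entry point itself still raises `NotImplementedError` in this commit, here is a standalone illustration of that mechanism using plain OmegaConf.

```python
from omegaconf import OmegaConf

base = OmegaConf.create({"name": "blt", "model": {"dim": 512}})
overrides = OmegaConf.from_dotlist(["name=tictac", "model.dim=32"])
print(OmegaConf.to_yaml(OmegaConf.merge(base, overrides)))
# name: tictac
# model:
#   dim: 32
```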
+import os +from dataclasses import replace + +import numpy as np +import pytest +import torch + +from bytelatent.constants import BLT_DATA +from bytelatent.data.data_types import Batch +from bytelatent.data.ngram_processor import NgramProcessor +from bytelatent.model.blt import ( + ByteLatentTransformer, + ByteLatentTransformerArgs, + EmbeddingType, + compute_hash_embeddings, + create_global_transformer, + create_local_decoder, + create_local_encoder, + cross_attn_mask, + decoder_patch_ids_from_lengths, + get_blt_input, + init_embeddings, + patch_ids_from_lengths, +) +from bytelatent.model.transformer import CrossAttention +from bytelatent.model.utils import create_causal_mask +from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.train import compute_loss + + +def batch_to_tensors_and_gpu(batch): + x = torch.from_numpy(batch.x) + y = torch.from_numpy(batch.y) + mask = None if batch.mask is None else torch.from_numpy(batch.mask) + patch_lengths = ( + None if batch.patch_lengths is None else torch.from_numpy(batch.patch_lengths) + ) + ngram_ids = None if batch.ngram_ids is None else torch.from_numpy(batch.ngram_ids) + + if torch.cuda.is_available(): + x = x.cuda() + y = y.cuda() + if mask is not None: + mask = mask.cuda() + if patch_lengths is not None: + patch_lengths = patch_lengths.cuda() + if ngram_ids is not None: + ngram_ids = ngram_ids.cuda() + return x, y, mask, patch_lengths, ngram_ids + + +def fake_batch(): + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + del batch_dict["x2"] + del batch_dict["y2"] + del batch_dict["src_names"] + return Batch(**batch_dict) + + +def create_args(cross_attention=False): + transformer_args = ByteLatentTransformerArgs( + # Base args provided + n_heads=8, + dim=512, + vocab_size=260, + # Additional args from command line + dim_token=256, + patch_size=6, + tokenization_mode="bytes", + patching_mode="space", + tie_local_encoder_decoder_logits=False, + data_loader_patching=True, + max_encoder_seq_length=12288, + pad_to_max_length=True, + encoder_lm_loss=False, + patching_threshold=3.1439168453216553, + encoder_hash_byte_group_size=[4], + encoder_hash_byte_group_vocab=50002, + encoder_hash_byte_group_nb_functions=3, + cross_attn_encoder=cross_attention, # True, + cross_attn_decoder=cross_attention, # True, + cross_attn_window_encoder=512, + cross_attn_window_decoder=512, + dim_local_encoder=256, + dim_local_decoder=256, + cross_attn_k=8, + cross_attn_nheads=4, + cross_attn_all_layers_decoder=True, + cross_attn_all_layers_encoder=True, + cross_attn_use_flex_attention=True, + cross_attn_init_by_pooling=True, + log_patch_lengths=True, + non_linearity="swiglu", + use_rope=True, + recompute_fc1_out=False, + recompute_fc3_out=False, + recompute_attn=False, + custom_bwd=False, + layer_ckpt="none", + efficient_attn="sdpa", + patch_only_encoder=False, + patch_only_decoder=False, + use_local_encoder_transformer=True, + init_use_gaussian=True, + init_use_depth="current", + attn_bias_type="block_causal", + alpha_depth="disabled", + max_length=256, + local_attention_window_len=512, + max_seqlen=12288, + downsampling_by_pooling="max", + ) + return transformer_args + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +class TestByteLatentTransformer: + def test_local_encoder(self): + args = create_args() + device = torch.device("cuda") + local_encoder = create_local_encoder(args).to(device) + + batch = fake_batch() + tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch) + + 
local_encoder_tokens, _, _ = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=0, + patch_size=local_encoder.patch_size, + boe_id=local_encoder.boe_id, + ) + + patch_ids = patch_ids_from_lengths( + patch_lengths, local_encoder_tokens.shape[-1] + ) + + encoder_hash_tok_embedding = init_embeddings( + args, + EmbeddingType.HASH_TOK, + local_encoder_dim=local_encoder.dim, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + ).to(device) + + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=local_encoder, + encoder_hash_tok_embedding=encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab, + ) + + reference_path = os.path.join(BLT_DATA, "local_encoder_tokens.pt") + reference_tokens = torch.load(reference_path).to(device) + torch.testing.assert_close( + local_encoder_tokens, + reference_tokens, + msg="Generated tokens don't match reference tokens", + ) + + (h_encoder, h_cross), cache_encoder = local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=None, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + assert h_encoder is not None + assert h_cross is None + assert cache_encoder is None + + expected_shape = ( + local_encoder_tokens.shape[0], + local_encoder_tokens.shape[1], + local_encoder.dim, + ) + assert h_encoder.shape == expected_shape + + def test_local_encoder_cross_attention(self): + args = create_args(cross_attention=True) + device = torch.device("cuda") + local_encoder = create_local_encoder(args).to(device) + + batch = fake_batch() + tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch) + + local_encoder_tokens, _, _ = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=0, + patch_size=local_encoder.patch_size, + boe_id=local_encoder.boe_id, + ) + + patch_ids = patch_ids_from_lengths( + patch_lengths, local_encoder_tokens.shape[-1] + ) + + encoder_hash_tok_embedding = init_embeddings( + args, + EmbeddingType.HASH_TOK, + local_encoder_dim=local_encoder.dim, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + ).to(device) + + cross_attn_mask_enc = cross_attn_mask( + patch_ids, + patch_lengths, + local_encoder_tokens.shape[-1], + patches_as_queries=True, + cross_attn_k=args.cross_attn_k, + window=args.cross_attn_window_encoder, + block_mask=True, + ) + + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=local_encoder, + encoder_hash_tok_embedding=encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab, + ) + (h_encoder, h_cross), cache_encoder = local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + assert h_encoder is not None + assert h_cross is not None + assert cache_encoder is None + expected_shape = ( + local_encoder_tokens.shape[0], + local_encoder_tokens.shape[1], + local_encoder.dim, + ) + assert h_encoder.shape == expected_shape + assert h_cross.shape == (2, 2048, local_encoder.dim) + + 
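As an aside on the bookkeeping these encoder tests rely on: `patch_ids_from_lengths` labels every byte position with the index of the patch containing it, which is what lets `downsample`/`patch_reduce` scatter byte embeddings into per-patch slots. The values below illustrate that reading and are not a fixture from the repo.

```python
import torch

from bytelatent.model.blt import patch_ids_from_lengths

patch_lengths = torch.tensor([[3, 2, 4]])            # one sequence, three patches
patch_ids = patch_ids_from_lengths(patch_lengths, 9)  # 9 = sequence length in bytes
# expected under this reading: tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2]])
```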
def test_local_decoder_cross_attention(self): + args = create_args(cross_attention=True) + device = torch.device("cuda") + local_decoder = create_local_decoder(args).to(device) + + test_files = { + "dec_embeds": "dec_embeds.pt", + "decoder_tokens": "local_decoder_tokens.pt", + "patch_embeds": "decoder_patch_cross_embeds.pt", + } + batch = fake_batch() + _, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch) + + tensors = { + name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device) + for name, filename in test_files.items() + } + decoder_patch_ids = decoder_patch_ids_from_lengths( + patch_lengths, 0, tensors["decoder_tokens"].shape[-1] + ) + cross_attn_mask_dec = cross_attn_mask( + decoder_patch_ids, + patch_lengths, + tensors["decoder_tokens"].shape[-1], + patches_as_queries=False, + cross_attn_k=args.cross_attn_k, + window=args.cross_attn_window_decoder, + block_mask=True, + ) + output, _ = local_decoder( + embeds=tensors["dec_embeds"], + patch_embeds=tensors["patch_embeds"], + tokens=tensors["decoder_tokens"], + cross_mask=cross_attn_mask_dec, + cache=None, + ) + assert output is not None + assert output.shape == (2, tensors["decoder_tokens"].shape[1], args.vocab_size) + + def test_local_decoder(self): + args = create_args() + device = torch.device("cuda") + + local_decoder = create_local_decoder(args).to(device) + + test_files = { + "dec_embeds": "dec_embeds.pt", + "decoder_tokens": "local_decoder_tokens.pt", + "patch_embeds": "decoder_patch_embeds.pt", + } + + tensors = { + name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device) + for name, filename in test_files.items() + } + + output, cache_decoder = local_decoder( + embeds=tensors["dec_embeds"], + patch_embeds=tensors["patch_embeds"], + tokens=tensors["decoder_tokens"], + cross_mask=None, + cache=None, + ) + assert output is not None + expected_shape = ( + tensors["decoder_tokens"].shape[0], + tensors["decoder_tokens"].shape[1], + args.vocab_size, + ) + assert output.shape == expected_shape + assert cache_decoder is None + + def test_global_transformer(self): + args = create_args() + device = torch.device("cuda") + global_transformer = create_global_transformer(args).to(device) + + test_files = { + "global_embeds": "global_embeds.pt", + "global_tokens": "global_tokens.pt", + } + tensors = { + name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device) + for name, filename in test_files.items() + } + h, cache = global_transformer( + embeds=tensors["global_embeds"], tokens=tensors["global_tokens"] + ) + h is not None + assert h.shape == (2, 256, 512) + assert cache is None + + def test_blt_transformer_init(self): + args = create_args() + model = ByteLatentTransformer(args) + assert model is not None + + @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) + def test_blt_transformer_forward(self, attn_type): + args = create_args() + args = args.model_copy(update=dict(efficient_attn=attn_type)) + model = ByteLatentTransformer(args) + model = model.cuda() + batch = fake_batch() + x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch) + + output = model( + tokens=x, + patch_lengths=patch_lengths, + ngram_ids=ngram_ids, + ) + assert output is not None + expected_shape = ( + x.shape[0], + x.shape[1], + args.vocab_size, + ) + assert output.shape == expected_shape + + def test_blt_transformer_cross_attn_forward(self): + args = create_args(cross_attention=True) + model = ByteLatentTransformer(args) + model = model.cuda() + batch = fake_batch() + x, y, mask, patch_lengths, ngram_ids = 
batch_to_tensors_and_gpu(batch) + + output = model( + tokens=x, + patch_lengths=patch_lengths, + ngram_ids=ngram_ids, + ) + assert output is not None + expected_shape = ( + x.shape[0], + x.shape[1], + args.vocab_size, + ) + assert output.shape == expected_shape + + def test_cross_attention_rand(self): + x = torch.randn(2, 256, 512, device="cuda") + kv = torch.randn(2, 256, 512, device="cuda") + cross_attention = CrossAttention( + dim=512, + head_dim=64, + n_heads=8, + n_kv_heads=4, + norm_eps=1e-6, + ).to("cuda") + mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None) + output = cross_attention(x, kv, mask) + assert output is not None + assert output.shape == (2, 256, 512) + + def test_ngram_embeddings(self): + ngram_to_size = { + 2: 38396, + 3: 50000, + 4: 50000, + 5: 50000, + 6: 50000, + 7: 50000, + 8: 50000, + } + batch = fake_batch() + ngram_processor = NgramProcessor(BLT_DATA, ngram_to_size) + ngram_ids = ngram_processor.encode_token_ngrams(batch.x) + ngram_ids = np.stack(ngram_ids, axis=0) + batch = replace(batch, ngram_ids=ngram_ids) + args = create_args(cross_attention=True) + args = args.model_copy( + update=dict( + encoder_ngram_to_size_str="2:38396,3:50000,4:50000,5:50000,6:50000,7:50000,8:50000", + encoder_enable_byte_ngrams=True, + ngram_vocab_sizes=ngram_processor.ngram_vocab_sizes, + ) + ) + model = ByteLatentTransformer(args) + model = model.cuda() + x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch) + + output = model( + tokens=x, + patch_lengths=patch_lengths, + ngram_ids=ngram_ids, + ) + assert output is not None + expected_shape = ( + x.shape[0], + x.shape[1], + args.vocab_size, + ) + assert output.shape == expected_shape + + def test_loss_backward(self): + args = create_args() + args = args.model_copy(update=dict(efficient_attn="sdpa")) + batch = fake_batch() + model = ByteLatentTransformer(args) + steps = 10 + optimizer, scheduler = build_optimizer(model, OptimArgs(lr=4e-04), steps) + model = model.cuda() + x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch) + + initial_loss = None + final_loss = None + for step in range(steps): + output = model( + tokens=x, + patch_lengths=patch_lengths, + ngram_ids=ngram_ids, + ) + loss, _ = compute_loss(output, y, mask, 1.0) + if step == 0: + initial_loss = loss.item() + if step == steps - 1: + final_loss = loss.item() + prev_loss = loss.item() + loss.backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + assert ( + final_loss < initial_loss + ), f"Training did not reduce loss: initial {initial_loss}, final {final_loss}" diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py new file mode 100644 index 0000000..3acc42d --- /dev/null +++ b/bytelatent/test_entropy_model.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
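The loss-decrease check above relies on compute_loss from bytelatent/train.py, which zeroes token losses wherever the mask is 0 and averages only over the unmasked positions. A small self-contained illustration with toy tensors (masked_ce is a re-statement for illustration; the shapes match, the values are made up):

    import torch
    import torch.nn.functional as F

    def masked_ce(pred, y, mask, scale=1.0):
        # mirrors compute_loss: pred [bs, seq, vocab], y and mask [bs, seq]
        tok_loss = scale * F.cross_entropy(
            pred.flatten(0, 1), y.flatten(0, 1), reduction="none"
        )
        if mask is None:
            return tok_loss.mean()
        mask = mask.flatten(0, 1)
        return (tok_loss * mask).sum() / (mask.sum() + 1e-6)

    pred = torch.randn(2, 4, 8)                                    # 2 sequences, 4 positions, vocab of 8
    y = torch.randint(0, 8, (2, 4))
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float)
    print(masked_ce(pred, y, mask))                                # mean over the 5 unmasked tokens only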
+import os + +import torch + +from bytelatent.constants import BLT_DATA +from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState +from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator +from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum, entropy +from bytelatent.entropy_model import load_entropy_model +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + +ENTROPY_MODEL = "transformer_100m" +ARROW_TEST_DATA = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow") + + +def test_entropy_model(): + initial_state = ArrowFileIteratorState( + file_path=None, + num_workers=1, + worker_id=0, + preprocess_dir=None, + entropy_model_name=ENTROPY_MODEL, + dataset_files=[ARROW_TEST_DATA], + row_num=0, + arrow_batch_size=100, + ) + arrow_file = initial_state.build() + tokenizer_args = TokenizerArgs( + name="blt", + init_kwargs={ + "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" + }, + ) + entropy_model = load_entropy_model( + BLT_DATA / "checkpoint_0100000_consolidated", + os.path.join( + BLT_DATA, + "entropy_model.pth", + ), + ) + preprocess_iter = PreprocessIterator( + arrow_file, + tokenizer_args=tokenizer_args, + patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.entropy), + add_patches=False, + ) + for example in preprocess_iter.create_iter(): + tokens = torch.tensor(example.tokens).unsqueeze(0) + expected_entropies = torch.tensor(example.entropies).unsqueeze(0) + preds = entropy_model(tokens) + pred_entropies = entropy(preds) + assert pred_entropies.shape == expected_entropies.shape + assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5) + break diff --git a/bytelatent/tokenizers/__init__.py b/bytelatent/tokenizers/__init__.py new file mode 100644 index 0000000..71ca4b1 --- /dev/null +++ b/bytelatent/tokenizers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/bytelatent/tokenizers/abstract_tokenizer.py b/bytelatent/tokenizers/abstract_tokenizer.py new file mode 100644 index 0000000..fff26c3 --- /dev/null +++ b/bytelatent/tokenizers/abstract_tokenizer.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import abc + + +class Tokenizer(abc.ABC): + @abc.abstractmethod + def encode(self, text: str, add_bos: bool, add_eos: bool): + pass + + @abc.abstractmethod + def decode(self, tokens: list[int]): + pass + + @abc.abstractmethod + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str], list[int]]: + """Return the offsets of the tokens in the original text. Only used for evaluation.""" + pass diff --git a/bytelatent/tokenizers/blt_tokenizer.py b/bytelatent/tokenizers/blt_tokenizer.py new file mode 100644 index 0000000..a3462e1 --- /dev/null +++ b/bytelatent/tokenizers/blt_tokenizer.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
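The entropy-model test above compares the patcher's entropy() output against precomputed values with a loose tolerance. Assuming entropy() returns the Shannon entropy of the model's predicted next-byte distribution at each position, an equivalent computation is sketched below (next_byte_entropy is illustrative, not the project's implementation):

    import torch

    def next_byte_entropy(logits: torch.Tensor) -> torch.Tensor:
        # logits: [batch, seq_len, vocab] -> per-position entropy in nats, H = -sum_v p_v * log p_v
        logp = torch.log_softmax(logits, dim=-1)
        return -(logp.exp() * logp).sum(dim=-1)

    # Positions where this value spikes above the patcher's threshold are natural patch boundaries.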
+import re + +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer +from bytelatent.tokenizers.constants import ( + BOE_ID, + BOS_ID, + BPE_ID, + BYTE_UNITS, + EOS_ID, + OFFSET, + PAD_ID, +) +from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer + + +def convert_to_bytes(s): + # check if the output is a bytes like object of the format <0x00> + if re.match(r"<0x[0-9a-fA-F]+>", s): + return bytes.fromhex(s[3:-1]) + else: + return bytes(s, "utf-8", errors="ignore") + + +def text2bytes_bpe_delims( + text: str, + *, + bpe_tokenizer, + bpe_id: int, + offsetting_special_char: int, + add_bos: bool, + add_eos: bool, +): + cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos) + # merge the leading space tokens + leading_space_tokens = [] + other_bpe_tokens = [] + leading = True + for token in cur_bpe: + bpe_str = bpe_tokenizer.sp_model.id_to_piece(token) + if leading and all(c == "▁" for c in bpe_str): + leading_space_tokens.append(bpe_str) + else: + leading = False + other_bpe_tokens.append(bpe_str) + cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens + + # Remove the '▁' characters + bpe_strs = [] + for i, bpe_str in enumerate(cur_bpe_strs): + if ( + len(bpe_strs) <= 1 + and all([c == " " for s in bpe_strs for c in s]) + and not all(c == "▁" for c in bpe_str) + ): + # Remove leading space for first non space token. + bpe_str = bpe_str.replace("▁", "") + elif i == 0 and all(c == "▁" for c in bpe_str): + bpe_str = " " * (len(text) - len(text.lstrip(" "))) + else: + bpe_str = bpe_str.replace("▁", " ") + if len(bpe_str) > 0: + bpe_strs.append(bpe_str) + ex_seq = [] + # Convert bpe tokens to bytes + for s in bpe_strs: + byte_chunk = convert_to_bytes(s) + proc_chunk = [int(unit) for unit in byte_chunk] + ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk) + + return ex_seq + + +class BltTokenizer(Tokenizer): + def __init__( + self, + *, + vocab_size_unit_1: int = BYTE_UNITS, + bpe_delim: bool = False, + bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", + add_bos: bool = True, + add_eos: bool = True, + ): + self.add_bos = add_bos + self.add_eos = add_eos + self.vocab_size_unit_1 = vocab_size_unit_1 + self.boe_id = BOE_ID + self.bos_id = BOS_ID + self.eos_id = EOS_ID + self.pad_id = PAD_ID + self.bpe_id = BPE_ID + self.bpe_tokenizer_path = bpe_tokenizer_path + if bpe_delim: + self.bpe_tokenizer = SentencePieceTokenizer( + model_path=self.bpe_tokenizer_path + ) + else: + self.bpe_tokenizer = None + self.bpe_delim = bpe_delim + self.offsetting_special_char = OFFSET + self.vocab_size_unit_1 = vocab_size_unit_1 + self.n_words = vocab_size_unit_1 + self.offsetting_special_char + + def encode( + self, text: str, add_bos: bool | None = None, add_eos: bool | None = None + ): + if add_bos is None: + add_bos = self.add_bos + if add_eos is None: + add_eos = self.add_eos + + if self.bpe_delim: + tokens = text2bytes_bpe_delims( + text, + bpe_tokenizer=self.bpe_tokenizer, + bpe_id=self.bpe_id, + offsetting_special_char=self.offsetting_special_char, + add_bos=False, + add_eos=False, + ) + else: + tokens = bytes(text, encoding="utf-8", errors="ignore") + + # Offsetting + tokens = [int(unit) + self.offsetting_special_char for unit in tokens] + + if add_bos: + tokens.insert(0, self.bos_id) + if add_eos: + tokens.append(self.eos_id) + + return tokens + + def decode(self, tokens: list[int], cut_at_eos: bool = False): + if cut_at_eos: + for k, t in enumerate(tokens): + if t == self.eos_id: + tokens = tokens[: k + 1] + break 
+ return bytes( + [ + tok - self.offsetting_special_char + for tok in tokens + if tok - self.offsetting_special_char >= 0 + ] + ).decode("utf-8", errors="ignore") + + def get_token_offsets(self, text: str, tokens: list[int] | None = None): + # TODO: Figure out what this does + raise NotImplementedError() diff --git a/bytelatent/tokenizers/build_tokenizer.py b/bytelatent/tokenizers/build_tokenizer.py new file mode 100644 index 0000000..8aa434d --- /dev/null +++ b/bytelatent/tokenizers/build_tokenizer.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +from typing import Any + +from pydantic import BaseModel + +from bytelatent.tokenizers.blt_tokenizer import BltTokenizer +from bytelatent.tokenizers.byte_tokenizer import ByteTokenizer +from bytelatent.tokenizers.tiktoken_tokenizer import TikTokenTokenizer + +try: + from sentencepiece import SentencePieceProcessor + + has_sp = True +except ImportError: + has_sp = False + +try: + import tiktoken + from tiktoken.load import load_tiktoken_bpe + + has_tiktoken = True +except ImportError: + has_tiktoken = False + +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer +from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer + +logger = logging.getLogger(__name__) + + +class MockTokenizer(Tokenizer): + n_words: int = 256 + + def encode(self, text: str, add_bos: bool, add_eos: bool): + return text + + def decode(self, tokens): + raise NotImplementedError() + + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str]]: + raise NotImplementedError() + + +class TokenizerArgs(BaseModel): + name: str = "bytes" + init_kwargs: dict[str, Any] | None = None + + def build(self) -> Tokenizer: + if self.init_kwargs is None: + init_kwargs = {} + else: + init_kwargs = self.init_kwargs + if self.name == "blt": + return BltTokenizer(**init_kwargs) + elif self.name == "bytes": + return ByteTokenizer(**init_kwargs) + elif self.name == "mock": + return MockTokenizer(**init_kwargs) + elif self.name == "sp": + assert has_sp, "sentencepiece not installed" + return SentencePieceTokenizer(**init_kwargs) + elif self.name == "tiktoken": + assert has_tiktoken, "tiktoken not installed" + return TikTokenTokenizer(**init_kwargs) + else: + raise NotImplementedError(f"{self.name} tokenizer type is not implemented") diff --git a/bytelatent/tokenizers/byte_tokenizer.py b/bytelatent/tokenizers/byte_tokenizer.py new file mode 100644 index 0000000..f85f4f7 --- /dev/null +++ b/bytelatent/tokenizers/byte_tokenizer.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
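Putting the encode/decode paths above together, a quick round trip shows the byte-plus-offset scheme (BOS_ID=1, EOS_ID=2, and every raw byte shifted up by OFFSET=4):

    from bytelatent.tokenizers.blt_tokenizer import BltTokenizer

    tok = BltTokenizer(bpe_delim=False)               # pure byte mode, no BPE delimiter tokens
    ids = tok.encode("Hi")
    assert ids == [1, ord("H") + 4, ord("i") + 4, 2]  # [1, 76, 109, 2]
    # decode drops anything that falls below 0 after removing the offset, i.e. the special tokens
    assert tok.decode(ids) == "Hi"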
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer + + +class ByteTokenizer(Tokenizer): + def __init__(self): + self.bos_id = 256 + self.eos_id = 257 + self.n_words = 258 + + def encode(self, s: str, add_bos: bool = False, add_eos: bool = False): + tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos + return tokens + + def decode(self, tokens: list[int]): + byte_tokens = bytes([t for t in tokens if t < 256]) + return byte_tokens.decode("utf-8", errors="backslashreplace") + + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str], list[int]]: + if tokens is None: + tokens = self.encode(text) + + decoded_chars, offsets = [], [] + byte_pos = 0 + for token in tokens: + if token < 256: + char = bytes([token]).decode("utf-8", errors="ignore") + if char: + decoded_chars.append(char) + offsets.append(byte_pos) + byte_pos += len(char.encode("utf-8")) + + return decoded_chars, offsets diff --git a/bytelatent/tokenizers/constants.py b/bytelatent/tokenizers/constants.py new file mode 100644 index 0000000..774119e --- /dev/null +++ b/bytelatent/tokenizers/constants.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + + +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 + +BYTE_UNITS: int = 256 diff --git a/bytelatent/tokenizers/sentence_piece_tokenizer.py b/bytelatent/tokenizers/sentence_piece_tokenizer.py new file mode 100644 index 0000000..faeb997 --- /dev/null +++ b/bytelatent/tokenizers/sentence_piece_tokenizer.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +import os + +try: + from sentencepiece import SentencePieceProcessor + + has_sp = True +except ImportError: + has_sp = False + +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer + +logger = logging.getLogger(__name__) + + +class SentencePieceTokenizer(Tokenizer): + def __init__( + self, model_path: str, add_bos: bool = True, add_eos: bool = True + ) -> None: + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + logger.info(f"Reloaded SentencePiece model from {model_path}") + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.pad_id() + self.add_bos = add_bos + self.add_eos = add_eos + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None): + if add_bos is None: + add_bos = self.add_bos + + if add_eos is None: + add_eos = self.add_eos + assert type(s) is str + tokens = ( + [self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos + ) + return tokens + + def decode(self, tokens: list[int]): + return self.sp_model.decode(tokens) + + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str], list[int]]: + pieces = self.sp_model.encode_as_immutable_proto(text).pieces + substrs = [p.surface for p in pieces] + offsets = [p.begin for p in pieces] + return substrs, offsets diff --git a/bytelatent/tokenizers/test_blt_tokenizer.py b/bytelatent/tokenizers/test_blt_tokenizer.py new file mode 100644 index 0000000..50308da --- /dev/null +++ b/bytelatent/tokenizers/test_blt_tokenizer.py @@ 
-0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import json + +from bytelatent.constants import BLT_DATA +from bytelatent.tokenizers.blt_tokenizer import BltTokenizer +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +def test_tokenizer_bytes(): + with open("fixtures/tokenizer_data.json") as f: + data = json.load(f) + + examples: list[str] = data["texts"] + examples_tokens: list[list[int]] = data["tokens"] + + tokenizer = BltTokenizer(bpe_delim=False) + for i in range(len(examples)): + assert tokenizer.encode(examples[i]) == examples_tokens[i] + + +def test_tokenizer_bpe(): + with open("fixtures/tokenizer_data_bpe_delim.json") as f: + data = json.load(f) + + examples: list[str] = data["texts"] + examples_tokens: list[list[int]] = data["tokens"] + + tokenizer = BltTokenizer(bpe_delim=True) + for i in range(len(examples)): + assert tokenizer.encode(examples[i]) == examples_tokens[i] + + +def test_build_tokenizer_from_args(): + tokenizer_args = TokenizerArgs( + name="blt", + init_kwargs={ + "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" + }, + ) + tokenizer = tokenizer_args.build() + assert tokenizer.encode("test text") is not None diff --git a/bytelatent/tokenizers/tiktoken_tokenizer.py b/bytelatent/tokenizers/tiktoken_tokenizer.py new file mode 100644 index 0000000..f498bf2 --- /dev/null +++ b/bytelatent/tokenizers/tiktoken_tokenizer.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +from copy import copy +from pathlib import Path + +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer + +try: + import tiktoken + from tiktoken.load import load_tiktoken_bpe + + has_tiktoken = True +except ImportError: + has_tiktoken = False +DEFAULT_TIKTOKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" +DEFAULT_TIKTOKEN_SPECIAL_TOKENS = { + "<|begin_of_text|>": 0, + "<|end_of_text|>": 1, + "<|fim_prefix|>": 2, + "<|fim_middle|>": 3, + "<|fim_end_fill|>": 253, + "<|fim_pad|>": 254, + "<|fim_suffix|>": 255, +} +TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +logger = logging.getLogger(__name__) + + +class TikTokenTokenizer(Tokenizer): + def __init__(self, model_path: str) -> None: + mergeable_ranks = load_tiktoken_bpe(model_path) + all_special_tokens_with_ids = copy(DEFAULT_TIKTOKEN_SPECIAL_TOKENS) + missing_ids = set(range(256)) - set(all_special_tokens_with_ids.values()) + for id in missing_ids: + all_special_tokens_with_ids[f"<|reserved_special_token_{id}|>"] = id + for name in all_special_tokens_with_ids: + all_special_tokens_with_ids[name] += len(mergeable_ranks) + + self.tkt_model = tiktoken.core.Encoding( + name=Path(model_path).stem, + pat_str=DEFAULT_TIKTOKEN_PATTERN, + mergeable_ranks=mergeable_ranks, + special_tokens=all_special_tokens_with_ids, + ) + + self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>") + self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>") + + self.n_words: int = self.tkt_model.n_vocab + + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + + def encode(self, s: str, add_bos: bool, add_eos: bool): + assert isinstance(s, str) + + subs = [] + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS): + subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS]) + return ( + [self.bos_id] * add_bos + + sum(self.tkt_model.encode_ordinary_batch(subs), start=[]) + + [self.eos_id] * add_eos + ) + + def decode(self, tokens: 
list[int]): + return self.tkt_model.decode(tokens) + + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str], list[int]]: + if tokens is not None: + token_bytes = self.tkt_model.decode_tokens_bytes(tokens) + else: + token_bytes = self.tkt_model.decode_tokens_bytes( + self.tkt_model.encode(text, allowed_special="all") + ) + + text_len, offsets = 0, [] + for token in token_bytes: + offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0))) + text_len += sum(1 for c in token if not 0x80 <= c < 0xC0) + substrs = [text[s:e] for s, e in zip(offsets, offsets[1:] + [None])] + return substrs, offsets diff --git a/bytelatent/train.py b/bytelatent/train.py new file mode 100644 index 0000000..6cb13b9 --- /dev/null +++ b/bytelatent/train.py @@ -0,0 +1,651 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import gc +import logging +import os +import sys +from contextlib import ExitStack +from copy import deepcopy +from dataclasses import asdict, dataclass +from pathlib import Path +from timeit import default_timer as timer +from typing import Any, Dict, Type, TypeVar + +import torch +import torch.distributed +import torch.nn.functional +import torch.nn.functional as F +import wandb +import xformers.profiler +from omegaconf import OmegaConf +from torch.distributed._tensor import DTensor +from torch.distributed.checkpoint.stateful import Stateful +from torch.optim import lr_scheduler + +from bytelatent.args import TrainArgs +from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.data_types import DataLoaderState +from bytelatent.distributed import ( + check_model_value_range, + clean_env, + dist_mean_dict, + get_device_mesh, + get_is_master, + get_world_size, + init_signal_handler, + parallelize_model, + requeue_slurm_job, + setup_env, + setup_torch_distributed, +) +from bytelatent.logger import init_logger +from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params +from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.optim import build_optimizer +from bytelatent.probe import AutoProbeD +from bytelatent.profiling import maybe_run_profiler +from bytelatent.stool import StoolArgs, launch_job +from bytelatent.transformer import ( + build_fsdp_grouping_plan, + get_no_recompute_ops, + get_num_flop_per_token, + tp_parallelize, +) + +logger = logging.getLogger() + +T = TypeVar("T") + + +def flatten_dict(d, parent_key="", sep="_"): + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T: + """ + Converts a dictionary to a dataclass instance, recursively for nested structures. 
+ """ + base = OmegaConf.structured(cls()) + OmegaConf.set_struct(base, strict) + override = OmegaConf.create(data) + return OmegaConf.to_object(OmegaConf.merge(base, override)) + + +@dataclass +class TrainState(Stateful): + step: int # Nb of steps taken by the optimizer + acc_step: int # Nb of accumulation steps done since last optimizer step + scheduler: lr_scheduler.LambdaLR + data_loader_state: DataLoaderState + scale: float = 1.0 + + def state_dict(self) -> Dict[str, Any]: + return { + "step": self.step, + "acc_step": self.acc_step, + "data_loader_state": self.data_loader_state.dict(), + "scheduler": self.scheduler.state_dict(), + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.acc_step = state_dict["acc_step"] + self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"]) + self.scheduler.load_state_dict(state_dict["scheduler"]) + + +def validate_train_args(args: TrainArgs, output_size: int): + if args.model.vocab_size < 0: + logger.info(f"Setting model output size to {args.model.vocab_size}") + args.model.vocab_size = output_size + + assert args.dump_dir, "Dump dir not set" + + if args.checkpoint.path is None: + logger.info(f"Setting checkpoint path to {args.checkpoint.path}") + args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints") + + for source in args.data.sources: + data_path = os.path.join(args.data.root_dir, source) + assert os.path.exists(data_path), f"{data_path} doesn't exist" + + if ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + != get_world_size() + ): + assert get_world_size() % args.distributed.dp_shard == 0 + args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard + + assert args.distributed.dp_replicate % args.distributed.tp_size == 0 + args.distributed.dp_replicate = ( + args.distributed.dp_replicate // args.distributed.tp_size + ) + + logger.warning( + f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}" + ) + assert ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + == get_world_size() + ) + + if args.distributed.fsdp_type == "no_shard": + assert ( + args.distributed.dp_shard == 1 + and args.distributed.dp_replicate == get_world_size() + ) + + args.model.max_seqlen = args.data.seq_len + + if args.distributed.tp_size == 1: + logger.warning( + "Tensor parallelism has not been tested for a while, use at your own risk" + ) + + assert ( + args.probe_freq != args.profiling.mem_steps + ), "Don't profile during probe step" + assert ( + args.probe_freq != args.profiling.profile_steps + ), "Don't profile during probe step" + if args.logging.wandb is not None: + args.logging.wandb.name = args.name + + if args.probe_freq is not None: + assert ( + args.distributed.tp_size == 1 + ), "Probing not supported with tensor parallelism" + assert ( + args.distributed.selective_activation_checkpointing is False + ), "Probing not supported with selective activation checkpointing" + + +preemption_flag = dict(flag=False) + + +def set_preemption_flag(signum, frame): + logger.warning("Signal handler called with signal " + str(signum)) + logger.warning("Preemption ! 
checkpointing asap and exiting.") + preemption_flag["flag"] = True + + +def every_n_steps(train_state, freq, acc_step=None, acc_freq=None): + test = train_state.step % freq == 0 + if acc_step is not None: + test = test and (train_state.acc_step == acc_step) + elif acc_freq is not None: + test = test and ((train_state.acc_step % acc_freq) == 0) + return test + + +def compute_loss(p, y, mask, scale): + tok_loss = scale * F.cross_entropy( + p.flatten(0, 1), y.flatten(0, 1), reduction="none" + ) + if mask is None: + loss = tok_loss.mean() + else: + mask = mask.flatten(0, 1) + tok_loss = tok_loss * mask + loss = tok_loss.sum() / (mask.sum() + 1e-6) + return loss, tok_loss + + +def train(args: TrainArgs): + with ExitStack() as context_stack: + tokenizer = args.data.tokenizer_args.build() + validate_train_args( + args, + tokenizer.n_words, + ) + if get_is_master(): + os.makedirs(args.dump_dir, exist_ok=True) + args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml") + init_logger(Path(args.dump_dir) / "train.log") + init_signal_handler(set_preemption_flag) # For handling preemption signals. + setup_env(args.env) + setup_torch_distributed(args.distributed) + world_mesh = get_device_mesh(args.distributed) + logger.info(f"Starting job: {args.name}") + + # build dataloader + # need dp world size and rank + dp_mesh = world_mesh["dp_replicate"] + dp_degree = dp_mesh.size() + dp_rank = dp_mesh.get_local_rank() + if args.distributed.dp_shard > 1: + dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank() + dp_degree *= world_mesh["dp_shard"].size() + + logger.info(f"Running on dp rank : {dp_rank}") + logger.info(f"Running on dp size : {dp_degree}") + + torch.manual_seed(args.seed) + logger.info("Building model") + + # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory + with torch.device("meta"): + model = ByteLatentTransformer(args.model) + logger.info("Model is built !") + + model_param_count = get_num_params(model) + + model = parallelize_model( + model, + world_mesh, + args.model, + args.distributed, + fsdp_grouping_plan=build_fsdp_grouping_plan(args.model), + tp_parallelize=tp_parallelize, + no_recompute_ops=get_no_recompute_ops(), + ) + + # Once we shard the model on different gpus we can actually initialize the model + # First we create empty tensors of the correct shapes + model = model.to_empty(device="cuda") + # Then we init the model. 
Please make sure this function initializes *ALL* parameters + # and buffers, otherwise you will have random values in the unitialized tensors + # which will silently fail (give nan gradients for example) + + if args.checkpoint.init_ckpt_path: + logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + load_from_checkpoint( + args.checkpoint.init_ckpt_path, model, model_key="model" + ) # Put model_key="" if its directly the model checkpoint + model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded + else: + with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + torch.manual_seed(args.model.seed) + model.init_weights() + check_model_value_range(model, range=10.0, std=1.0) + + # log model size + + logger.info(f"Model size: {model_param_count:,} total parameters") + + gpu_memory_monitor = GPUMemoryMonitor("cuda") + logger.info( + f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) " + f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory" + ) + logger.info(f"GPU memory usage: {gpu_memory_monitor}") + + # build optimizer after apply parallelisms to the model + optimizer, scheduler = build_optimizer(model, args.optim, args.steps) + data_loader = args.data.build_from_rank(dp_rank, dp_degree) + data_loader_state = data_loader.get_state() + + train_state = TrainState( + step=0, + acc_step=0, + data_loader_state=data_loader_state, + scheduler=scheduler, + scale=1.0, + ) + + checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint) + checkpoint.load(model, optimizer, train_state, world_mesh) + # Either load from latest checkpoint or start from scratch + if args.probe_freq is not None: + if get_is_master(): + os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True) + torch.distributed.barrier() + probe = AutoProbeD( + model, + ( + Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl" + if (dp_rank % 128 == 0) + else None + ), + ) + probe_mod = model._orig_mod if args.distributed.compile else model + + gc.disable() + + # train loop + model.train() + metric_logger = context_stack.enter_context( + MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args) + ) + data_loader = train_state.data_loader_state.build() + batch_iterator = data_loader.create_iter() + + torch_profiler = context_stack.enter_context( + maybe_run_profiler(args.dump_dir, model, args.profiling) + ) + + nwords_since_last_log = 0 + time_last_log = timer() + gc.collect() + while train_state.step < args.steps: + # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 + train_state.acc_step += 1 + train_state.acc_step = train_state.acc_step % args.grad_acc_steps + + # get batch + curr_lr = float(optimizer.param_groups[0]["lr"]) + data_load_start = timer() + batch = next(batch_iterator) + batch_x = torch.from_numpy( + batch.x, + ).cuda() + batch_y = torch.from_numpy(batch.y).cuda() + batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() + mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + + if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None: + raise ValueError( + "Cannot enable byte ngrams and have batch.ngram_ids be None" + ) + ngram_ids = ( + None + if batch.ngram_ids is None + else torch.from_numpy(batch.ngram_ids).cuda() + ) + + if every_n_steps(train_state, args.gc_collect_freq, acc_step=0): + logger.info("garbage collection") + # we do garbage collection manually otherwise different processes + # run the GC at 
different times so they slow down the whole pipeline + gc.collect() + + data_load_time = round(timer() - data_load_start, 4) + nwords_since_last_log += batch_x.numel() + + bsz, seqlen = batch_y.shape + + # forward + start_timer = torch.cuda.Event(enable_timing=True) + end_timer = torch.cuda.Event(enable_timing=True) + start_timer.record() + + # This is an automatic probe that will compute statistics + # of all linears' inputs, weights and outputs + # along with attention logits and entropy + # both in forward and backward pass + tok_loss = None + if (args.probe_freq is not None) and every_n_steps( + train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps + ): + # Here we do a fake forward and backward pass on a smaller + # batch size to avoid OOM + # This assumes the model has no stateful layers (batch norm..) + assert ( + next(probe_mod.parameters()).grad is None + ), "Can't probe model if grads are not reset" + + with probe: + probe.metadata = { + "it": train_state.step, + "global_step": train_state.step, + "loop": "lingua", + } + # Non compiled model uses roughly 2x memory in our exps + # So we divide bsz by 2 or seqlen by 2 + probe_bsz = max(1, bsz // 2) + probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2) + probe_loss = probe_mod( + batch_x[:probe_bsz, :probe_seq], + batch_y[:probe_bsz, :probe_seq], + ) + probe_loss.backward() + # We zero grads to cancel this fake step + optimizer.zero_grad() + + assert ( + next(probe_mod.parameters()).grad is None + ), "Probe model shouldn't have grads at this point" + + pred = model( + batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids + ) + + loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + + # We scale loss with grad_acc_steps so the gradient is the same + # regardless of grad_acc_steps + loss = loss / args.grad_acc_steps + + # backward on scaled loss to create scaled gradients + loss.backward() + # For logging we undo that scaling + loss = loss.detach() * args.grad_acc_steps + + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + + grad_norm = ( + grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm + ).item() + + # optimizer step + if train_state.acc_step == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + train_state.step += 1 + + # updates the scale for next iteration + # training iteration complete + end_timer.record() + + torch.cuda.synchronize() + + curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4) + + # if profiler is active + if torch_profiler: + xformers.profiler.step() + + # log metrics + if every_n_steps( + train_state, + args.logging.freq, + acc_step=None if args.logging.acc_freq else 0, + acc_freq=args.logging.acc_freq, + ): + time_delta = timer() - time_last_log + wps = nwords_since_last_log / (time_delta * args.distributed.tp_size) + + gpu_mem_stats = gpu_memory_monitor.get_peak_stats() + + total_acc_steps = ( + args.grad_acc_steps * train_state.step + train_state.acc_step + ) + tokens_per_gpu = ( + total_acc_steps * args.data.batch_size * args.data.seq_len + ) + total_tokens = dp_degree * tokens_per_gpu + # This is an estimate and the correct values may change + # if you change the architecture + # Use xformer's analyze profile trace to get actual measurement + FLOPS = ( + get_num_flop_per_token( + model_param_count - args.model.vocab_size * args.model.dim, + args.model.n_layers, + args.model.dim, + args.data.seq_len, + ) + * wps + ) + + metrics = flatten_dict( + { + 
"global_step": train_state.step, + "acc_step": train_state.acc_step, + "speed": { + "wps": wps, + "FLOPS": FLOPS, + "curr_iter_time": curr_iter_time, + "data_load_time": data_load_time, + }, + "optim": { + "grad_norm": grad_norm, + "lr": curr_lr, + "total_tokens": total_tokens, + }, + "memory": gpu_mem_stats._asdict(), + }, + sep="/", + ) + + to_sync = {} + to_sync["loss/out"] = loss.item() + metrics.update(dist_mean_dict(to_sync)) + + if get_is_master(): + metric_logger.log(metrics) + + gpu_memory_monitor.reset_peak_stats() + nwords_since_last_log = 0 + time_last_log = timer() + logger.info( + f"step: {train_state.step}" + f" acc: {train_state.acc_step}" + f" loss: {round(loss.item(),4):>7}" + f" grad: {grad_norm:.2e}" + f" flops: {FLOPS:.2e}" + f" wps: {wps:.2e}" + f" iter: {curr_iter_time:>7}" + f" data: {data_load_time:>5}" + f" lr: {curr_lr:.2e}" + f" mem: {gpu_mem_stats.max_active_pct:.0f}%" + f" pow: {gpu_mem_stats.power_draw/1000} W" + ) + + saved = False + if every_n_steps( + train_state, args.checkpoint.dump.every, acc_step=0 + ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + saved = checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + + if args.eval is not None and every_n_steps( + train_state, args.checkpoint.eval.every, acc_step=0 + ): + from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval + + eval_args = dataclass_from_dict(EvalArgs, args.eval) + + eval_args.global_step = train_state.step + eval_args.ckpt_dir = str(checkpoint.existing_saves[-1]) + eval_args.dump_dir = str( + os.path.join( + args.dump_dir, + "evals", + EVAL_FOLDER_NAME.format(train_state.step), + ) + ) + eval_args.metric_log_dir = args.dump_dir + if args.async_eval_gpus is None: + launch_eval(eval_args) + elif get_is_master(): + if wandb.run is not None and args.logging.wandb is not None: + eval_args.wandb = deepcopy(args.logging.wandb) + assert args.async_eval_gpus > 0 + logger.info(f"Launching evals on {args.async_eval_gpus} gpus") + with clean_env(): + launch_job( + StoolArgs( + asdict(eval_args), + script="apps.main.eval", + copy_code=False, + nodes=args.async_eval_gpus // 8, + qos="lowest", + ) + ) + + if preemption_flag["flag"]: + if not saved: + checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + requeue_slurm_job() + sys.exit(0) + + if not saved: + checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + gc.collect() + + +def main(): + """ + The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments + This accepts arguments as a dot list + So if the dataclass looks like + + @dataclass + class DummyArgs: + name: str + model: LMTransformerArgsgs + + @dataclass + class LMTransformerArgsgs: + dim: int + + Then you can pass model.dim=32 to change values in LMTransformerArgsgs + or just name=tictac for top level attributes. + + The behavior here is as follows: + 1. We instantiate TrainArgs with its default values + 2. We override those default values with the ones in the provided config file + 3. We override the result with the additional arguments provided through command line + + For example, if the config is the following + + model: + dim: 128 + n_layers: 4 + + and you call train.py with train.py model.dim=64 + + Then the final TrainArgs will have + + model: + dim: 64 + n_layers: 4 + + Plus all the default values in TrainArgs dataclass. 
+ """ + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.create(TrainArgs().model_dump()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + train_args = TrainArgs.model_validate(cfg) + train(train_args) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py new file mode 100644 index 0000000..432f7df --- /dev/null +++ b/bytelatent/transformer.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, + parallelize_module, +) +from torch.nn.attention.flex_attention import BlockMask, create_block_mask +from xformers.ops import AttentionBias, fmha + +from bytelatent.base_transformer import ( + BaseTransformer, + BaseTransformerArgs, + RMSNorm, + cross_entropy, +) + + +def create_causal_mask(seqlen, attn_impl, sliding_window): + if sliding_window is not None and attn_impl == "xformers": + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) + elif attn_impl == "xformers": + return fmha.attn_bias.LowerTriangularMask() + elif attn_impl == "sdpa": + return "causal" + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + else: + raise NotImplementedError( + f"Attention {attn_impl} with {sliding_window} sliding window not implemented" + ) + + +def attention_flops_per_token(n_layers, seq_len, dim, causal): + # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30 + return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1)) + + +def get_num_flop_per_token( + num_non_embed_params: int, n_layers: int, dim: int, seq_len: int +) -> int: + return 6 * num_non_embed_params + attention_flops_per_token( + n_layers, seq_len, dim, True + ) + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +class LMTransformerArgs(BaseTransformerArgs): + seed: int = 42 + + vocab_size: int = -1 + weight_tying: bool = False + + sliding_window: int | None = None + + +class LMTransformer(BaseTransformer): + def __init__(self, args: LMTransformerArgs): + super().__init__(args) + self.weight_tying = args.weight_tying + self.sliding_window = args.sliding_window + + assert args.vocab_size > 0 + + self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) + + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + + self.output = nn.Linear( + args.dim, + args.vocab_size, + bias=False, + ) + + if args.weight_tying: + self.output.weight = self.embeddings.tok_embeddings.weight + + def forward( + self, + token_values: torch.Tensor, + target: Optional[torch.Tensor] = None, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, + attn_impl: str = "sdpa", + ): + bsz, seqlen = token_values.shape + + h = self.tok_embeddings(token_values) + + mask = ( + mask + if mask is not None + else create_causal_mask(seqlen, attn_impl, self.sliding_window) + ) + h = super().forward(h, 
tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + + logits = self.output(self.norm(h)) + if target is not None: + return cross_entropy(logits, target) + else: + return logits + + def reset_parameters(self, init_std=None): + # Either use fixed base std or sqrt model dim + super().reset_parameters() + init_std = init_std or (self.dim ** (-0.5)) + self.norm.reset_parameters() + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + if not self.weight_tying: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + +# Optional policy for activation checkpointing. With None, we stick to the default (defined distributed.py: default_no_recompute_ops) +def get_no_recompute_ops(): + return None + + +# Optional and only used for fully shard options (fsdp) is choose. Highly recommanded for large models +def build_fsdp_grouping_plan(model_args: LMTransformerArgs): + group_plan: Tuple[int, bool] = [] + + # Grouping and output seperately + group_plan.append(("tok_embeddings", False)) + + # Grouping by layers + for i in range(model_args.n_layers): + group_plan.append((f"layers.{i}", False)) + + group_plan.append(("output", True)) + + return group_plan + + +# Optional and only used for model/tensor parallelism when tp_size > 1 +def tp_parallelize(model, tp_mesh, model_args: LMTransformerArgs, distributed_args): + assert model_args.dim % distributed_args.tp_size == 0 + assert model_args.vocab_size % distributed_args.tp_size == 0 + assert model_args.n_heads % distributed_args.tp_size == 0 + assert (model_args.n_kv_heads or 0) % distributed_args.tp_size == 0 + assert model_args.n_heads % (model_args.n_kv_heads or 1) == 0 + + # Embedding layer tp + main_plan = {} + main_plan["tok_embeddings"] = ColwiseParallel( + input_layouts=Replicate(), output_layouts=Shard(1) + ) + main_plan["norm"] = SequenceParallel() + main_plan["output"] = ColwiseParallel( + input_layouts=Shard(1), output_layouts=Replicate() + ) + + parallelize_module( + model, + tp_mesh, + main_plan, + ) + + # Attention layers tp + for layer in model.layers: + layer_plan = {} + + layer_plan["attention"] = PrepareModuleInput( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ) + layer_plan["attention_norm"] = SequenceParallel() + layer_plan["attention.wq"] = ColwiseParallel() + layer_plan["attention.wk"] = ColwiseParallel() + layer_plan["attention.wv"] = ColwiseParallel() + layer_plan["attention.wo"] = RowwiseParallel(output_layouts=Shard(1)) + + # Feedforward layers tp + layer_plan["feed_forward"] = PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ) + layer_plan["ffn_norm"] = SequenceParallel() + layer_plan["feed_forward.w1"] = ColwiseParallel() + layer_plan["feed_forward.w3"] = ColwiseParallel() + layer_plan["feed_forward.w2"] = RowwiseParallel(output_layouts=Shard(1)) + + parallelize_module( + layer, + tp_mesh, + layer_plan, + ) + + # Adjusting the number of heads and kv heads according to the tp size + attn_layer = layer.attention + attn_layer.n_heads = attn_layer.n_heads // distributed_args.tp_size + attn_layer.n_kv_heads = attn_layer.n_kv_heads // distributed_args.tp_size diff --git a/dev/lint.sh b/dev/lint.sh new file mode 100644 index 0000000..cde5ce1 --- /dev/null +++ b/dev/lint.sh @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +isort . +black . 
diff --git a/fixtures/tokenizer_data.json b/fixtures/tokenizer_data.json new file mode 100644 index 0000000..1aeb9e2 --- /dev/null +++ b/fixtures/tokenizer_data.json @@ -0,0 +1 @@ +{"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 80, 105, 120, 43, 119, 36, 103, 108, 105, 103, 111, 36, 109, 106, 36, 120, 108, 105, 119, 105, 36, 120, 115, 111, 105, 114, 109, 126, 105, 118, 119, 36, 113, 101, 120, 103, 108, 37, 2]]} \ No newline at end of file diff --git a/fixtures/tokenizer_data_bpe_delim.json b/fixtures/tokenizer_data_bpe_delim.json new file mode 100644 index 0000000..c726f6e --- /dev/null +++ b/fixtures/tokenizer_data_bpe_delim.json @@ -0,0 +1 @@ +{"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 3, 80, 105, 120, 3, 43, 3, 119, 3, 36, 103, 108, 105, 103, 111, 3, 36, 109, 106, 3, 36, 120, 108, 105, 119, 105, 3, 36, 120, 115, 111, 105, 114, 3, 109, 126, 105, 118, 119, 3, 36, 113, 101, 120, 103, 108, 3, 37, 2]]} \ No newline at end of file diff --git a/fixtures/tokens_with_entropies.json b/fixtures/tokens_with_entropies.json new file mode 100644 index 0000000..e3589d8 --- /dev/null +++ b/fixtures/tokens_with_entropies.json @@ -0,0 +1 @@ +{"position":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":6,"7":7,"8":8,"9":9,"10":10,"11":11,"12":12,"13":13,"14":14,"15":15,"16":16,"17":17,"18":18,"19":19,"20":20,"21":21,"22":22,"23":23,"24":24,"25":25,"26":26,"27":27,"28":28,"29":29,"30":30,"31":31,"32":32,"33":33,"34":34,"35":35,"36":36,"37":37,"38":38,"39":39,"40":40,"41":41,"42":42,"43":43,"44":44,"45":45,"46":46,"47":47,"48":48,"49":49,"50":50,"51":51,"52":52,"53":53,"54":54,"55":55,"56":56,"57":57,"58":58,"59":59,"60":60,"61":61,"62":62,"63":63,"64":64,"65":65,"66":66,"67":67,"68":68,"69":69,"70":70,"71":71,"72":72,"73":73,"74":74,"75":75,"76":76,"77":77,"78":78,"79":79,"80":80},"tokens":{"0":"<","1":"D","2":"a","3":"e","4":"n","5":"e","6":"r","7":"y","8":"s","9":"_","10":"T","11":"a","12":"r","13":"g","14":"a","15":"r","16":"y","17":"e","18":"n","19":"_","20":"i","21":"s","22":"_","23":"i","24":"n","25":"_","26":"G","27":"a","28":"m","29":"e","30":"_","31":"o","32":"f","33":"_","34":"T","35":"h","36":"r","37":"o","38":"n","39":"e","40":"s","41":",","42":"_","43":"a","44":"_","45":"f","46":"a","47":"n","48":"t","49":"a","50":"s","51":"y","52":"_","53":"e","54":"p","55":"i","56":"c","57":"_","58":"b","59":"y","60":"_","61":"G","62":"e","63":"o","64":"r","65":"g","66":"e","67":"_","68":"R","69":".","70":"R","71":".","72":"_","73":"M","74":"a","75":"r","76":"t","77":"i","78":"n","79":".","80":">"},"token_ids":{"0":1,"1":72,"2":101,"3":105,"4":114,"5":105,"6":118,"7":125,"8":119,"9":36,"10":88,"11":101,"12":118,"13":107,"14":101,"15":118,"16":125,"17":105,"18":114,"19":36,"20":109,"21":119,"22":36,"23":109,"24":114,"25":36,"26":75,"27":101,"28":113,"29":105,"30":36,"31":115,"32":106,"33":36,"34":88,"35":108,"36":118,"37":115,"38":114,"39":105,"40":119,"41":48,"42":36,"43":101,"44":36,"45":106,"46":101,"47":114,"48":120,"49":101,"50":119,"51":125,"52":36,"53":105,"54":116,"55":109,"56":103,"57":36,"58":102,"59":125,"60":36,"61":75,"62":105,"63":115,"64":118,"65":107,"66":105,"67":36,"68":86,"69":50,"70":86,"71":50,"72":36,"73":81,"74":101,"75":118,"76":120,"77":109,"78":114,"79":50,"80":2},"entropies":{"0":3.3949158192,"1":2.1656451225,"2":2.3216569424,"3":2.8214058876,"4":1.5249242783,"5":0.0401624143,"6":0.0981037766,"7":0.0544578359,"8":0.3430138826,"9":1.0546212196,"10":0.25252828,"11":0.1494535804,"12":0.0624754503,"13":0.001355894,"14":0.0050173439,"15":0.0052358187
,"16":0.0011725067,"17":0.0010307421,"18":1.0241208076,"19":3.6867966652,"20":0.4502205253,"21":0.0484119244,"22":2.2572875023,"23":0.3789347112,"24":1.0042934418,"25":2.9090054035,"26":1.8933598995,"27":1.3859074116,"28":0.3827198744,"29":0.2646365762,"30":1.7742085457,"31":0.0136727821,"32":0.0053820172,"33":0.5485631227,"34":0.2064044327,"35":0.0049266233,"36":0.0005439016,"37":0.0007023578,"38":0.0004170335,"39":0.0054524317,"40":1.1938130856,"41":0.0238215197,"42":3.1279797554,"43":1.3883389235,"44":3.0503094196,"45":1.695879817,"46":1.8551058769,"47":1.4570231438,"48":0.0047810897,"49":0.026396824,"50":0.6633765101,"51":0.3141393065,"52":2.8411159515,"53":1.143143177,"54":0.0520330966,"55":0.3398066461,"56":0.4140175879,"57":2.5563707352,"58":1.3370712996,"59":0.0227173548,"60":3.4447185993,"61":1.8576486111,"62":0.8189754486,"63":0.6776530743,"64":0.0677763447,"65":0.212713033,"66":0.1003480032,"67":0.1746164262,"68":0.4123829603,"69":0.5507118702,"70":0.1047425047,"71":0.0194335245,"72":0.001482119,"73":0.0009310447,"74":0.0002176317,"75":0.0076908777,"76":0.0003866984,"77":0.0008008487,"78":1.2395234108,"79":0.4564163089,"80":0.0000461392},"patch":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":5,"7":5,"8":5,"9":5,"10":5,"11":5,"12":5,"13":5,"14":5,"15":5,"16":5,"17":5,"18":5,"19":5,"20":6,"21":6,"22":6,"23":7,"24":7,"25":7,"26":8,"27":9,"28":10,"29":10,"30":10,"31":11,"32":11,"33":11,"34":11,"35":11,"36":11,"37":11,"38":11,"39":11,"40":11,"41":11,"42":11,"43":12,"44":13,"45":14,"46":15,"47":16,"48":17,"49":17,"50":17,"51":17,"52":17,"53":18,"54":18,"55":18,"56":18,"57":18,"58":19,"59":20,"60":20,"61":21,"62":22,"63":22,"64":22,"65":22,"66":22,"67":22,"68":22,"69":22,"70":22,"71":22,"72":22,"73":22,"74":22,"75":22,"76":22,"77":22,"78":22,"79":22,"80":22},"start":{"0":1,"1":1,"2":1,"3":1,"4":1,"5":1,"6":0,"7":0,"8":0,"9":0,"10":0,"11":0,"12":0,"13":0,"14":0,"15":0,"16":0,"17":0,"18":0,"19":0,"20":1,"21":0,"22":0,"23":1,"24":0,"25":0,"26":1,"27":1,"28":1,"29":0,"30":0,"31":1,"32":0,"33":0,"34":0,"35":0,"36":0,"37":0,"38":0,"39":0,"40":0,"41":0,"42":0,"43":1,"44":1,"45":1,"46":1,"47":1,"48":1,"49":0,"50":0,"51":0,"52":0,"53":1,"54":0,"55":0,"56":0,"57":0,"58":1,"59":1,"60":0,"61":1,"62":1,"63":0,"64":0,"65":0,"66":0,"67":0,"68":0,"69":0,"70":0,"71":0,"72":0,"73":0,"74":0,"75":0,"76":0,"77":0,"78":0,"79":0,"80":0}} \ No newline at end of file diff --git a/plot_data/entropy_figure.json b/plot_data/entropy_figure.json new file mode 100644 index 0000000..79e42d9 --- /dev/null +++ b/plot_data/entropy_figure.json @@ -0,0 +1 @@ +{"text":"Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. 
Martin.","threshold":1.335442066192627,"dataframe_json":"{\"position\":{\"0\":0,\"1\":1,\"2\":2,\"3\":3,\"4\":4,\"5\":5,\"6\":6,\"7\":7,\"8\":8,\"9\":9,\"10\":10,\"11\":11,\"12\":12,\"13\":13,\"14\":14,\"15\":15,\"16\":16,\"17\":17,\"18\":18,\"19\":19,\"20\":20,\"21\":21,\"22\":22,\"23\":23,\"24\":24,\"25\":25,\"26\":26,\"27\":27,\"28\":28,\"29\":29,\"30\":30,\"31\":31,\"32\":32,\"33\":33,\"34\":34,\"35\":35,\"36\":36,\"37\":37,\"38\":38,\"39\":39,\"40\":40,\"41\":41,\"42\":42,\"43\":43,\"44\":44,\"45\":45,\"46\":46,\"47\":47,\"48\":48,\"49\":49,\"50\":50,\"51\":51,\"52\":52,\"53\":53,\"54\":54,\"55\":55,\"56\":56,\"57\":57,\"58\":58,\"59\":59,\"60\":60,\"61\":61,\"62\":62,\"63\":63,\"64\":64,\"65\":65,\"66\":66,\"67\":67,\"68\":68,\"69\":69,\"70\":70,\"71\":71,\"72\":72,\"73\":73,\"74\":74,\"75\":75,\"76\":76,\"77\":77,\"78\":78,\"79\":79,\"80\":80},\"tokens\":{\"0\":\"<\",\"1\":\"D\",\"2\":\"a\",\"3\":\"e\",\"4\":\"n\",\"5\":\"e\",\"6\":\"r\",\"7\":\"y\",\"8\":\"s\",\"9\":\"_\",\"10\":\"T\",\"11\":\"a\",\"12\":\"r\",\"13\":\"g\",\"14\":\"a\",\"15\":\"r\",\"16\":\"y\",\"17\":\"e\",\"18\":\"n\",\"19\":\"_\",\"20\":\"i\",\"21\":\"s\",\"22\":\"_\",\"23\":\"i\",\"24\":\"n\",\"25\":\"_\",\"26\":\"G\",\"27\":\"a\",\"28\":\"m\",\"29\":\"e\",\"30\":\"_\",\"31\":\"o\",\"32\":\"f\",\"33\":\"_\",\"34\":\"T\",\"35\":\"h\",\"36\":\"r\",\"37\":\"o\",\"38\":\"n\",\"39\":\"e\",\"40\":\"s\",\"41\":\",\",\"42\":\"_\",\"43\":\"a\",\"44\":\"_\",\"45\":\"f\",\"46\":\"a\",\"47\":\"n\",\"48\":\"t\",\"49\":\"a\",\"50\":\"s\",\"51\":\"y\",\"52\":\"_\",\"53\":\"e\",\"54\":\"p\",\"55\":\"i\",\"56\":\"c\",\"57\":\"_\",\"58\":\"b\",\"59\":\"y\",\"60\":\"_\",\"61\":\"G\",\"62\":\"e\",\"63\":\"o\",\"64\":\"r\",\"65\":\"g\",\"66\":\"e\",\"67\":\"_\",\"68\":\"R\",\"69\":\".\",\"70\":\"R\",\"71\":\".\",\"72\":\"_\",\"73\":\"M\",\"74\":\"a\",\"75\":\"r\",\"76\":\"t\",\"77\":\"i\",\"78\":\"n\",\"79\":\".\",\"80\":\">\"},\"token_ids\":{\"0\":1,\"1\":72,\"2\":101,\"3\":105,\"4\":114,\"5\":105,\"6\":118,\"7\":125,\"8\":119,\"9\":36,\"10\":88,\"11\":101,\"12\":118,\"13\":107,\"14\":101,\"15\":118,\"16\":125,\"17\":105,\"18\":114,\"19\":36,\"20\":109,\"21\":119,\"22\":36,\"23\":109,\"24\":114,\"25\":36,\"26\":75,\"27\":101,\"28\":113,\"29\":105,\"30\":36,\"31\":115,\"32\":106,\"33\":36,\"34\":88,\"35\":108,\"36\":118,\"37\":115,\"38\":114,\"39\":105,\"40\":119,\"41\":48,\"42\":36,\"43\":101,\"44\":36,\"45\":106,\"46\":101,\"47\":114,\"48\":120,\"49\":101,\"50\":119,\"51\":125,\"52\":36,\"53\":105,\"54\":116,\"55\":109,\"56\":103,\"57\":36,\"58\":102,\"59\":125,\"60\":36,\"61\":75,\"62\":105,\"63\":115,\"64\":118,\"65\":107,\"66\":105,\"67\":36,\"68\":86,\"69\":50,\"70\":86,\"71\":50,\"72\":36,\"73\":81,\"74\":101,\"75\":118,\"76\":120,\"77\":109,\"78\":114,\"79\":50,\"80\":2},\"entropies\":{\"0\":3.3949158192,\"1\":2.1656451225,\"2\":2.3216569424,\"3\":2.8214058876,\"4\":1.5249242783,\"5\":0.0401624143,\"6\":0.0981037766,\"7\":0.0544578359,\"8\":0.3430138826,\"9\":1.0546212196,\"10\":0.25252828,\"11\":0.1494535804,\"12\":0.0624754503,\"13\":0.001355894,\"14\":0.0050173439,\"15\":0.0052358187,\"16\":0.0011725067,\"17\":0.0010307421,\"18\":1.0241208076,\"19\":3.6867966652,\"20\":0.4502205253,\"21\":0.0484119244,\"22\":2.2572875023,\"23\":0.3789347112,\"24\":1.0042934418,\"25\":2.9090054035,\"26\":1.8933598995,\"27\":1.3859074116,\"28\":0.3827198744,\"29\":0.2646365762,\"30\":1.7742085457,\"31\":0.0136727821,\"32\":0.0053820172,\"33\":0.5485631227,\"34\":0.2064044327,\"35\":0.0049266233,\"36\":0.0005439016,\"37\":0.0007023578,\"38\"
:0.0004170335,\"39\":0.0054524317,\"40\":1.1938130856,\"41\":0.0238215197,\"42\":3.1279797554,\"43\":1.3883389235,\"44\":3.0503094196,\"45\":1.695879817,\"46\":1.8551058769,\"47\":1.4570231438,\"48\":0.0047810897,\"49\":0.026396824,\"50\":0.6633765101,\"51\":0.3141393065,\"52\":2.8411159515,\"53\":1.143143177,\"54\":0.0520330966,\"55\":0.3398066461,\"56\":0.4140175879,\"57\":2.5563707352,\"58\":1.3370712996,\"59\":0.0227173548,\"60\":3.4447185993,\"61\":1.8576486111,\"62\":0.8189754486,\"63\":0.6776530743,\"64\":0.0677763447,\"65\":0.212713033,\"66\":0.1003480032,\"67\":0.1746164262,\"68\":0.4123829603,\"69\":0.5507118702,\"70\":0.1047425047,\"71\":0.0194335245,\"72\":0.001482119,\"73\":0.0009310447,\"74\":0.0002176317,\"75\":0.0076908777,\"76\":0.0003866984,\"77\":0.0008008487,\"78\":1.2395234108,\"79\":0.4564163089,\"80\":0.0000461392},\"patch\":{\"0\":0,\"1\":1,\"2\":2,\"3\":3,\"4\":4,\"5\":5,\"6\":5,\"7\":5,\"8\":5,\"9\":5,\"10\":5,\"11\":5,\"12\":5,\"13\":5,\"14\":5,\"15\":5,\"16\":5,\"17\":5,\"18\":5,\"19\":5,\"20\":6,\"21\":6,\"22\":6,\"23\":7,\"24\":7,\"25\":7,\"26\":8,\"27\":9,\"28\":10,\"29\":10,\"30\":10,\"31\":11,\"32\":11,\"33\":11,\"34\":11,\"35\":11,\"36\":11,\"37\":11,\"38\":11,\"39\":11,\"40\":11,\"41\":11,\"42\":11,\"43\":12,\"44\":13,\"45\":14,\"46\":15,\"47\":16,\"48\":17,\"49\":17,\"50\":17,\"51\":17,\"52\":17,\"53\":18,\"54\":18,\"55\":18,\"56\":18,\"57\":18,\"58\":19,\"59\":20,\"60\":20,\"61\":21,\"62\":22,\"63\":22,\"64\":22,\"65\":22,\"66\":22,\"67\":22,\"68\":22,\"69\":22,\"70\":22,\"71\":22,\"72\":22,\"73\":22,\"74\":22,\"75\":22,\"76\":22,\"77\":22,\"78\":22,\"79\":22,\"80\":22},\"start\":{\"0\":1,\"1\":1,\"2\":1,\"3\":1,\"4\":1,\"5\":1,\"6\":0,\"7\":0,\"8\":0,\"9\":0,\"10\":0,\"11\":0,\"12\":0,\"13\":0,\"14\":0,\"15\":0,\"16\":0,\"17\":0,\"18\":0,\"19\":0,\"20\":1,\"21\":0,\"22\":0,\"23\":1,\"24\":0,\"25\":0,\"26\":1,\"27\":1,\"28\":1,\"29\":0,\"30\":0,\"31\":1,\"32\":0,\"33\":0,\"34\":0,\"35\":0,\"36\":0,\"37\":0,\"38\":0,\"39\":0,\"40\":0,\"41\":0,\"42\":0,\"43\":1,\"44\":1,\"45\":1,\"46\":1,\"47\":1,\"48\":1,\"49\":0,\"50\":0,\"51\":0,\"52\":0,\"53\":1,\"54\":0,\"55\":0,\"56\":0,\"57\":0,\"58\":1,\"59\":1,\"60\":0,\"61\":1,\"62\":1,\"63\":0,\"64\":0,\"65\":0,\"66\":0,\"67\":0,\"68\":0,\"69\":0,\"70\":0,\"71\":0,\"72\":0,\"73\":0,\"74\":0,\"75\":0,\"76\":0,\"77\":0,\"78\":0,\"79\":0,\"80\":0}}"} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e2ecd0d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +[tool.isort] +profile = "black" +known_bytelatent = "bytelatent" +known_apps = "apps" +sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c6d87f1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +numpy +omegaconf +msgspec +rouge-score +sacrebleu +sentencepiece +tiktoken +fsspec +blobfile +wandb +viztracer +lm-eval +scipy +pynvml +datatrove +orjson +luigi +pydantic +altair +submitit +typer +rich diff --git a/setup/create_env.sh b/setup/create_env.sh new file mode 100644 index 0000000..9c7dad9 --- /dev/null +++ b/setup/create_env.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+
+#SBATCH --job-name=env_creation
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --gres=gpu:8
+#SBATCH --exclusive
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=128
+#SBATCH --mem=0
+#SBATCH --time=01:00:00
+
+# Exit immediately if a command exits with a non-zero status
+set -e
+
+# Start timer
+start_time=$(date +%s)
+
+# Get the current date
+current_date=$(date +%y%m%d)
+
+# Create environment name with the current date
+env_prefix=blt_$current_date
+
+# Create the conda environment
+
+source $CONDA_ROOT/etc/profile.d/conda.sh
+conda create -n $env_prefix python=3.11 -y -c anaconda
+conda activate $env_prefix
+
+echo "Currently in env $(which python)"
+
+# Install packages
+pip install torch==2.5.0 xformers --index-url https://download.pytorch.org/whl/cu121
+pip install ninja
+pip install --requirement requirements.txt
+
+# End timer
+end_time=$(date +%s)
+
+# Calculate elapsed time in seconds
+elapsed_time=$((end_time - start_time))
+
+# Convert elapsed time to minutes
+elapsed_minutes=$((elapsed_time / 60))
+
+echo "Environment $env_prefix created and all packages installed successfully in $elapsed_minutes minutes!"
diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py
new file mode 100644
index 0000000..1aacf8f
--- /dev/null
+++ b/setup/download_prepare_hf_data.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import argparse
+import os
+import subprocess
+import time
+
+import requests
+from huggingface_hub import snapshot_download
+
+
+def run_command(command):
+    print(f"Running: {command}")
+    subprocess.run(command, shell=True, check=True)
+
+
+def download_dataset(repo_id, local_dir, allow_patterns):
+    print(f"Downloading dataset from {repo_id}...")
+    max_retries = 5
+    retry_delay = 10  # seconds
+    for attempt in range(max_retries):
+        try:
+            snapshot_download(
+                repo_id,
+                repo_type="dataset",
+                local_dir=local_dir,
+                allow_patterns=allow_patterns,
+                resume_download=True,
+                max_workers=16,  # Don't hesitate to increase this number to lower the download time
+            )
+            break
+        except requests.exceptions.ReadTimeout:
+            if attempt < max_retries - 1:
+                print(f"Timeout occurred. Retrying in {retry_delay} seconds...")
+                time.sleep(retry_delay)
+            else:
+                raise
+    print(f"Dataset downloaded to {local_dir}")
+
+
+def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
+    from datatrove.executor import LocalPipelineExecutor
+    from datatrove.pipeline.readers import ParquetReader
+    from datatrove.pipeline.writers import JsonlWriter
+
+    pipeline_exec = LocalPipelineExecutor(
+        pipeline=[
+            ParquetReader(
+                src_dir,
+                file_progress=True,
+                doc_progress=True,
+                glob_pattern="**/*.parquet",
+            ),
+            JsonlWriter(
+                tgt_dir,
+                output_filename=dataset + ".chunk.${rank}.jsonl",
+                compression=None,
+            ),
+        ],
+        tasks=ntasks,
+        logging_dir=os.path.join(work_dir, "datatrove"),
+    )
+    pipeline_exec.run()
+
+
+def setup_terashuf(work_dir):
+    terashuf_dir = os.path.join(work_dir, "terashuf")
+    terashuf_executable = os.path.join(terashuf_dir, "terashuf")
+
+    if os.path.exists(terashuf_executable):
+        print("terashuf executable already exists. Skipping setup.")
+        return terashuf_dir
+
+    print("Setting up terashuf...")
+    run_command(f"git clone https://github.com/alexandres/terashuf {terashuf_dir}")
+    run_command(f"make -C {terashuf_dir}")
+    return terashuf_dir
+
+
+def main(dataset, memory, data_dir, seed=42, nchunks=32):
+    # Configuration
+    repo_id = {
+        "fineweb_edu": "HuggingFaceFW/fineweb-edu",
+        "fineweb_edu_10bt": "HuggingFaceFW/fineweb-edu",
+        "dclm_baseline_1.0": "mlfoundations/dclm-baseline-1.0",
+        "dclm_baseline_1.0_10prct": "mlfoundations/dclm-baseline-1.0",
+    }[dataset]
+    src_dir = f"{data_dir}/{dataset}"
+    out_dir = f"{src_dir}_shuffled"
+    os.makedirs(out_dir, exist_ok=True)
+    work_dir = src_dir  # Working directory for terashuf and the datatrove logs
+    prefix = f"{dataset}.chunk."
+    orig_extension = {
+        "fineweb_edu": ".jsonl",
+        "fineweb_edu_10bt": ".jsonl",
+        "dclm_baseline_1.0": ".jsonl.zst",
+        "dclm_baseline_1.0_10prct": ".jsonl.zst",
+    }[dataset]
+    cat_command = {
+        "fineweb_edu": "cat",
+        "fineweb_edu_10bt": "cat",
+        "dclm_baseline_1.0": "zstdcat",
+        "dclm_baseline_1.0_10prct": "zstdcat",
+    }[dataset]
+    allow_patterns = {
+        "fineweb_edu": None,
+        "fineweb_edu_10bt": "sample/10BT/*",
+        "dclm_baseline_1.0": "*.jsonl.zst",
+        "dclm_baseline_1.0_10prct": "global-shard_01_of_10/*.jsonl.zst",
+    }[dataset]
+    suffix = ".jsonl"
+    k_validation = 10000  # Number of lines to take from each chunk for validation
+
+    # Setup terashuf
+    terashuf_dir = setup_terashuf(work_dir)
+
+    # Download dataset
+    download_dataset(repo_id, src_dir, allow_patterns)
+
+    if "fineweb" in dataset:
+        parquet_to_jsonl(dataset, work_dir, src_dir, src_dir)
+
+    # Set up environment variables
+    os.environ["MEMORY"] = f"{memory}"
+    os.environ["SEED"] = f"{seed}"
+
+    # Shuffle the data with terashuf and split it into nchunks chunks
+    terashuf_executable = os.path.join(terashuf_dir, "terashuf")
+    run_command(
+        f"ulimit -n 100000 && "
+        f"find {src_dir} -type f -name '*{orig_extension}' -print0 | xargs -0 {cat_command} | {terashuf_executable} | "
+        f"split -n r/{nchunks} -d --suffix-length 2 --additional-suffix {suffix} - {out_dir}/{prefix}"
+        "; trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' SIGPIPE;"
+    )
+
+    # Create validation set and remove lines from chunks
+    validation_file = f"{out_dir}/{dataset}.val{suffix}"
+    for i in range(nchunks):
+        chunk_file = f"{out_dir}/{prefix}{i:02d}{suffix}"
+        run_command(f"head -n {k_validation} {chunk_file} >> {validation_file}")
+        run_command(f"sed -i '1,{k_validation}d' {chunk_file}")
+
+    print("All tasks completed successfully!")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("dataset", type=str)
+    parser.add_argument("memory", type=float, default=8)
+    parser.add_argument("--data_dir", type=str, default="data")
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--nchunks", type=int, default=32)
+
+    args = parser.parse_args()
+
+    main(args.dataset, args.memory, args.data_dir, args.seed, args.nchunks)
diff --git a/setup/download_tokenizer.py b/setup/download_tokenizer.py
new file mode 100644
index 0000000..fc9c8d5
--- /dev/null
+++ b/setup/download_tokenizer.py
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import argparse
+import os
+from typing import Optional
+
+from requests.exceptions import HTTPError
+
+TOKENIZER = {
+    "llama2": ("meta-llama/Llama-2-7b", "tokenizer.model"),
+    "llama3": ("meta-llama/Meta-Llama-3-8B", "original/tokenizer.model"),
+    "gemma": ("google/gemma-2-9b", "tokenizer.model"),
+}
+
+
+def main(tokenizer_name: str, path_to_save: str, api_key: Optional[str] = None):
+    if tokenizer_name in TOKENIZER:
+        repo_id, filename = TOKENIZER[tokenizer_name]
+
+        from huggingface_hub import hf_hub_download
+
+        try:
+            hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                local_dir=path_to_save,
+                local_dir_use_symlinks=False,
+                token=api_key if api_key else None,
+            )
+        except HTTPError as e:
+            if e.response.status_code == 401:
+                print(
+                    "You need to pass a valid `--api_key=...` to download private checkpoints."
+                )
+            else:
+                raise e
+    else:
+        from tiktoken import get_encoding
+
+        if "TIKTOKEN_CACHE_DIR" not in os.environ:
+            os.environ["TIKTOKEN_CACHE_DIR"] = path_to_save
+        try:
+            get_encoding(tokenizer_name)
+        except ValueError:
+            print(
+                f"Tokenizer {tokenizer_name} not found. Please check the name and try again."
+            )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("tokenizer_name", type=str)
+    parser.add_argument("tokenizer_dir", type=str)
+    parser.add_argument("--api_key", type=str, default="")
+    args = parser.parse_args()
+
+    main(
+        tokenizer_name=args.tokenizer_name,
+        path_to_save=args.tokenizer_dir,
+        api_key=args.api_key,
+    )
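+
+# Example invocations (editorial addition, illustrative only; `tokenizers/` is a placeholder directory
+# and $HF_TOKEN stands for your Hugging Face access token):
+#   python setup/download_tokenizer.py llama3 tokenizers/ --api_key=$HF_TOKEN
+#   python setup/download_tokenizer.py cl100k_base tokenizers/   # names not listed above are resolved via tiktoken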