mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 08:27:45 +00:00
Initial commit
This commit is contained in:
commit
bcc039bb75
12
.github/workflows/black.yml
vendored
Normal file
12
.github/workflows/black.yml
vendored
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
name: Lint with Black
|
||||||
|
|
||||||
|
on: [push, pull_request]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: psf/black@stable
|
||||||
|
with:
|
||||||
|
version: "24.8.0"
|
10
.github/workflows/isort.yml
vendored
Normal file
10
.github/workflows/isort.yml
vendored
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
name: Lint with isort
|
||||||
|
|
||||||
|
on: [push, pull_request]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: isort/isort-action@master
|
168
.gitignore
vendored
Normal file
168
.gitignore
vendored
Normal file
|
@ -0,0 +1,168 @@
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||||
|
.pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
*.out
|
||||||
|
|
||||||
|
figures/
|
||||||
|
.vscode/
|
||||||
|
.DS_Store
|
||||||
|
|
8
.prettierrc
Normal file
8
.prettierrc
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
{
|
||||||
|
"overrides": [
|
||||||
|
{
|
||||||
|
"files": "*.yaml",
|
||||||
|
"options": { "tabWidth": 2 }
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
80
CODE_OF_CONDUCT.md
Normal file
80
CODE_OF_CONDUCT.md
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
# Code of Conduct
|
||||||
|
|
||||||
|
## Our Pledge
|
||||||
|
|
||||||
|
In the interest of fostering an open and welcoming environment, we as
|
||||||
|
contributors and maintainers pledge to make participation in our project and
|
||||||
|
our community a harassment-free experience for everyone, regardless of age, body
|
||||||
|
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||||
|
level of experience, education, socio-economic status, nationality, personal
|
||||||
|
appearance, race, religion, or sexual identity and orientation.
|
||||||
|
|
||||||
|
## Our Standards
|
||||||
|
|
||||||
|
Examples of behavior that contributes to creating a positive environment
|
||||||
|
include:
|
||||||
|
|
||||||
|
* Using welcoming and inclusive language
|
||||||
|
* Being respectful of differing viewpoints and experiences
|
||||||
|
* Gracefully accepting constructive criticism
|
||||||
|
* Focusing on what is best for the community
|
||||||
|
* Showing empathy towards other community members
|
||||||
|
|
||||||
|
Examples of unacceptable behavior by participants include:
|
||||||
|
|
||||||
|
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||||
|
advances
|
||||||
|
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||||
|
* Public or private harassment
|
||||||
|
* Publishing others' private information, such as a physical or electronic
|
||||||
|
address, without explicit permission
|
||||||
|
* Other conduct which could reasonably be considered inappropriate in a
|
||||||
|
professional setting
|
||||||
|
|
||||||
|
## Our Responsibilities
|
||||||
|
|
||||||
|
Project maintainers are responsible for clarifying the standards of acceptable
|
||||||
|
behavior and are expected to take appropriate and fair corrective action in
|
||||||
|
response to any instances of unacceptable behavior.
|
||||||
|
|
||||||
|
Project maintainers have the right and responsibility to remove, edit, or
|
||||||
|
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||||
|
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||||
|
permanently any contributor for other behaviors that they deem inappropriate,
|
||||||
|
threatening, offensive, or harmful.
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
|
||||||
|
This Code of Conduct applies within all project spaces, and it also applies when
|
||||||
|
an individual is representing the project or its community in public spaces.
|
||||||
|
Examples of representing a project or community include using an official
|
||||||
|
project e-mail address, posting via an official social media account, or acting
|
||||||
|
as an appointed representative at an online or offline event. Representation of
|
||||||
|
a project may be further defined and clarified by project maintainers.
|
||||||
|
|
||||||
|
This Code of Conduct also applies outside the project spaces when there is a
|
||||||
|
reasonable belief that an individual's behavior may have a negative impact on
|
||||||
|
the project or its community.
|
||||||
|
|
||||||
|
## Enforcement
|
||||||
|
|
||||||
|
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||||
|
reported by contacting the project team at <opensource-conduct@meta.com>. All
|
||||||
|
complaints will be reviewed and investigated and will result in a response that
|
||||||
|
is deemed necessary and appropriate to the circumstances. The project team is
|
||||||
|
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||||
|
Further details of specific enforcement policies may be posted separately.
|
||||||
|
|
||||||
|
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||||
|
faith may face temporary or permanent repercussions as determined by other
|
||||||
|
members of the project's leadership.
|
||||||
|
|
||||||
|
## Attribution
|
||||||
|
|
||||||
|
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||||
|
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||||
|
|
||||||
|
[homepage]: https://www.contributor-covenant.org
|
||||||
|
|
||||||
|
For answers to common questions about this code of conduct, see
|
||||||
|
https://www.contributor-covenant.org/faq
|
36
CONTRIBUTING.md
Normal file
36
CONTRIBUTING.md
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
# Contributing to
|
||||||
|
|
||||||
|
We want to make contributing to this project as easy and transparent as
|
||||||
|
possible.
|
||||||
|
|
||||||
|
## Pull Requests
|
||||||
|
|
||||||
|
We actively welcome your pull requests.
|
||||||
|
|
||||||
|
1. Fork the repo and create your branch from `main`.
|
||||||
|
2. If you've added code that should be tested, add tests.
|
||||||
|
3. If you've changed APIs, update the documentation.
|
||||||
|
4. Ensure the test suite passes.
|
||||||
|
5. Make sure your code lints.
|
||||||
|
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
||||||
|
|
||||||
|
## Contributor License Agreement ("CLA")
|
||||||
|
|
||||||
|
In order to accept your pull request, we need you to submit a CLA. You only need
|
||||||
|
to do this once to work on any of Meta's open source projects.
|
||||||
|
|
||||||
|
Complete your CLA here: <https://code.facebook.com/cla>
|
||||||
|
|
||||||
|
## Issues
|
||||||
|
|
||||||
|
We use GitHub issues to track public bugs. Please ensure your description is
|
||||||
|
clear and has sufficient instructions to be able to reproduce the issue.
|
||||||
|
|
||||||
|
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
||||||
|
disclosure of security bugs. In those cases, please go through the process
|
||||||
|
outlined on that page and do not file a public issue.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
By contributing to BLT, you agree that your contributions will be licensed
|
||||||
|
under the LICENSE file in the root directory of this source tree.
|
28
LICENSE
Normal file
28
LICENSE
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
BSD 3-Clause License
|
||||||
|
|
||||||
|
Copyright 2024 Meta
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without modification,
|
||||||
|
are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright notice,this list
|
||||||
|
of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither the name of the copyright holder nor the names of its contributors may
|
||||||
|
be used to endorse or promote products derived from this software without specific
|
||||||
|
prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
|
||||||
|
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
|
||||||
|
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
|
||||||
|
SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
|
||||||
|
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
|
||||||
|
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
|
||||||
|
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
|
||||||
|
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
|
||||||
|
DAMAGE.
|
117
README.md
Normal file
117
README.md
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
# Byte Latent Transformer
|
||||||
|
|
||||||
|
This repository contains code for our paper: "Byte Latent Transformer: Patches Scale Better Than Tokens"
|
||||||
|
|
||||||
|
- [Paper Link](https://dl.fbaipublicfiles.com/blt/BLT__Patches_Scale_Better_Than_Tokens.pdf)
|
||||||
|
|
||||||
|
## Abstract
|
||||||
|
|
||||||
|
We introduce the Byte Latent Transformer architecture (BLTs), a new byte-level LLM architecture that
|
||||||
|
for the first time, matches tokenization-based LLM performance at scale, with significant improvements
|
||||||
|
in inference efficiency and robustness. BLT encodes bytes into dynamically sized patches, which serve
|
||||||
|
as the primary units of computation. Patches are segmented dynamically based on the entropy of the
|
||||||
|
next byte, allocating more compute and model capacity where there is more data complexity. The BLT
|
||||||
|
architecture includes new attention mechanisms to maximize the information flow between byte and
|
||||||
|
patch hidden representations and a new type of byte-sequence memory. We present the first scaling
|
||||||
|
study of byte-level models up to 8B parameters and 8T training bytes, showing for the first time
|
||||||
|
that we can train a model end-to-end at scale from bytes with no tokenization or other preprocessing.
|
||||||
|
Scaling trends reveal training and inference efficiency benefits from dynamically selecting very long
|
||||||
|
patches on average, along with qualitative improvements with reasoning and long tail generalization
|
||||||
|
from modeling byte-sequences.
|
||||||
|
|
||||||
|
![BLT Architecture Diagram](blt-figure.jpg)
|
||||||
|
|
||||||
|
## Development Status
|
||||||
|
|
||||||
|
We are actively updating the blt code to make it easier to reproduce our results.
|
||||||
|
Please file an issue and/or be patient while we make more of our code public!
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
The following commands launch a SLURM job that creates an environment for Meta Lingua.
|
||||||
|
The env creation should take around 5 minutes without counting downloads.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/facebookresearch/blt
|
||||||
|
cd blt
|
||||||
|
|
||||||
|
bash setup/create_env.sh
|
||||||
|
# or if you have access to a SLURM cluster
|
||||||
|
sbatch setup/create_env.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
Once that is done your can activate the environment
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda activate blt_<date>
|
||||||
|
```
|
||||||
|
|
||||||
|
use the provided script to download and prepare data from huggingface (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
|
||||||
|
This command will download the `fineweb_edu` and prepare it for training in the `./data` directory, specifying the amount of memory `terashuf` (the tool used to shuffle samples) will be allocated. By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` with the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python setup/download_prepare_hf_data.py fineweb_edu <MEMORY> --data_dir ./data --seed 42 --nchunks <NCHUNKS>
|
||||||
|
```
|
||||||
|
|
||||||
|
to download tokenizer (here llama3), use the folowing script:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python setup/download_tokenizer.py llama3 <SAVE_PATH> --api_key <HUGGINGFACE_TOKEN>
|
||||||
|
```
|
||||||
|
|
||||||
|
Now launch a debug job to check if everything works. **The provided configurations are templates, you need to adapt them for them to work (change `dump_dir`, `data.root_dir`, `data.tokenizer.path`, etc ...)**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# stool stands for SLURM tool !
|
||||||
|
python -m bytelatent.stool script=bytelatent.train config=apps/bytelatent/configs/debug.yaml nodes=1 partition=<partition>
|
||||||
|
# if you want to launch locally you can use torchrun
|
||||||
|
torchrun --nproc-per-node 8 -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
|
||||||
|
# or you can also launch on 1 GPU
|
||||||
|
python -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
When using `stool`, if a job crashes, it can be relaunched using sbatch:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sbatch path/to/dump_dir/submit.slurm
|
||||||
|
```
|
||||||
|
|
||||||
|
## Linting
|
||||||
|
|
||||||
|
To lint, run the following command
|
||||||
|
|
||||||
|
```
|
||||||
|
bash dev/lint.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
The BLT is partially based on Meta Lingua, so consider citing it in addition to our BLT paper if you re-use our work.
|
||||||
|
|
||||||
|
BLT Paper Citation (will be updated to arXiv soon)
|
||||||
|
|
||||||
|
```
|
||||||
|
@article{meta_blt,
|
||||||
|
author = {Artidoro Pagnoni, Ram Pasunuru, Pedro Rodriguez, John Nguyen, Benjamin Muller, Margaret Li, Chunting Zhou, Lili Yu, Jason Weston, Luke Zettlemoyer, Gargi Ghosh, Mike Lewis, Ari Holtzman†, Srinivasan Iyer},
|
||||||
|
title = {Byte Latent Transformer: Patches Scale Better Than Tokens},
|
||||||
|
url = {https://github.com/facebookresearch/blt},
|
||||||
|
year = {2024}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Lingua Code
|
||||||
|
|
||||||
|
```
|
||||||
|
@misc{meta_lingua,
|
||||||
|
author = {Mathurin Videau, Badr Youbi Idrissi, Daniel Haziza, Luca Wehrstedt, Jade Copet, Olivier Teytaud, David Lopez-Paz},
|
||||||
|
title = {{Meta Lingua}: A minimal {PyTorch LLM} training library},
|
||||||
|
url = {https://github.com/facebookresearch/lingua},
|
||||||
|
year = {2024}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
The BLT code is partially based on Meta Lingia.
|
||||||
|
|
||||||
|
Meta Lingua is licensed under BSD-3-Clause license. Refer to the LICENSE file in the top level directory.
|
0
apps/__init__.py
Normal file
0
apps/__init__.py
Normal file
0
apps/main/__init__.py
Normal file
0
apps/main/__init__.py
Normal file
35
apps/main/configs/eval.yaml
Normal file
35
apps/main/configs/eval.yaml
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
name: "debug_evals"
|
||||||
|
# ckpt_dir: !!CHANGETHIS!!
|
||||||
|
# dump_dir: !!CHANGETHIS!!
|
||||||
|
generator:
|
||||||
|
max_tokens: 8192
|
||||||
|
dtype: bf16
|
||||||
|
temperature: 1.0
|
||||||
|
top_p: 0.95
|
||||||
|
harness:
|
||||||
|
tasks:
|
||||||
|
- hellaswag
|
||||||
|
- task: boolq
|
||||||
|
dataset_kwargs:
|
||||||
|
trust_remote_code: true
|
||||||
|
- task: nq_open
|
||||||
|
num_fewshot: 5
|
||||||
|
- piqa
|
||||||
|
- task: social_iqa
|
||||||
|
dataset_kwargs:
|
||||||
|
trust_remote_code: true
|
||||||
|
- triviaqa
|
||||||
|
- winogrande
|
||||||
|
- openbookqa
|
||||||
|
- arc_easy
|
||||||
|
- arc_challenge
|
||||||
|
- race
|
||||||
|
- commonsense_qa
|
||||||
|
# - coqa
|
||||||
|
- copa
|
||||||
|
- gsm8k
|
||||||
|
- bbh
|
||||||
|
- mmlu
|
||||||
|
- mmlu_pro
|
||||||
|
validation:
|
||||||
|
max_steps: 1000
|
87
apps/main/configs/llama_1B.yaml
Normal file
87
apps/main/configs/llama_1B.yaml
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
# dump_dir: !!!CHANGE_THIS!!!
|
||||||
|
name: large_lm
|
||||||
|
steps: 60_000
|
||||||
|
probe_freq: null
|
||||||
|
seed: 777
|
||||||
|
|
||||||
|
optim:
|
||||||
|
lr: 3e-3
|
||||||
|
weight_decay: 0.033
|
||||||
|
warmup: 5000
|
||||||
|
lr_min_ratio: 0.000001
|
||||||
|
clip: 1.0
|
||||||
|
|
||||||
|
distributed:
|
||||||
|
fsdp_type: full_shard
|
||||||
|
compile: true
|
||||||
|
model_dtype: bf16
|
||||||
|
matmul_allow_tf32: false
|
||||||
|
selective_activation_checkpointing: false
|
||||||
|
tp_size: 1
|
||||||
|
|
||||||
|
model:
|
||||||
|
dim: 2048
|
||||||
|
n_layers: 25
|
||||||
|
n_heads: 16
|
||||||
|
|
||||||
|
data:
|
||||||
|
root_dir: data/shuffled
|
||||||
|
sources:
|
||||||
|
dclm_baseline_1.0: 100.0
|
||||||
|
batch_size: 4
|
||||||
|
prefetch_size: 1024
|
||||||
|
seq_len: 4096
|
||||||
|
n_views: 2
|
||||||
|
load_async: true
|
||||||
|
add_bos: true
|
||||||
|
add_eos: true
|
||||||
|
tokenizer:
|
||||||
|
name: tiktoken
|
||||||
|
path: tokenizers/cl_toplang_128k.tiktoken
|
||||||
|
|
||||||
|
profiling:
|
||||||
|
run: true
|
||||||
|
mem_warmup: 0
|
||||||
|
mem_steps: 4
|
||||||
|
profile_warmup: 100
|
||||||
|
profile_steps: 4
|
||||||
|
|
||||||
|
checkpoint:
|
||||||
|
dump:
|
||||||
|
every: 2500
|
||||||
|
keep: 3
|
||||||
|
eval:
|
||||||
|
every: 5000
|
||||||
|
keep: -1
|
||||||
|
|
||||||
|
logging:
|
||||||
|
freq: 1
|
||||||
|
|
||||||
|
async_eval_gpus: 8
|
||||||
|
eval:
|
||||||
|
harness:
|
||||||
|
tasks:
|
||||||
|
- hellaswag
|
||||||
|
- task: boolq
|
||||||
|
dataset_kwargs:
|
||||||
|
trust_remote_code: true
|
||||||
|
- piqa
|
||||||
|
- task: social_iqa
|
||||||
|
dataset_kwargs:
|
||||||
|
trust_remote_code: true
|
||||||
|
- winogrande
|
||||||
|
- openbookqa
|
||||||
|
- arc_easy
|
||||||
|
- arc_challenge
|
||||||
|
- race
|
||||||
|
- commonsense_qa
|
||||||
|
- copa
|
||||||
|
# - coqa
|
||||||
|
# - task: nq_open
|
||||||
|
# num_fewshot: 5
|
||||||
|
# - triviaqa
|
||||||
|
validation:
|
||||||
|
max_steps: 1000
|
||||||
|
generator:
|
||||||
|
max_tokens: 16384
|
||||||
|
dtype: bf16
|
95
apps/main/configs/llama_7B.yaml
Normal file
95
apps/main/configs/llama_7B.yaml
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
#python -m lingua.stool config=apps/main/configs/llama2_7B.yaml nodes=32 account=fair_amaia_cw_codegen qos=lowest
|
||||||
|
# dump_dir: !!!CHANGE_THIS!!!
|
||||||
|
name: "7b_baseline"
|
||||||
|
steps: 100_000
|
||||||
|
grad_acc_steps: 1
|
||||||
|
probe_freq: 100
|
||||||
|
|
||||||
|
seed: 777
|
||||||
|
optim:
|
||||||
|
lr: 1.0e-3
|
||||||
|
weight_decay: 0.1
|
||||||
|
warmup: 2000
|
||||||
|
lr_min_ratio: 0.000001
|
||||||
|
clip: 1.0
|
||||||
|
|
||||||
|
distributed:
|
||||||
|
fsdp_type: full_shard
|
||||||
|
compile: true
|
||||||
|
model_dtype: bf16
|
||||||
|
matmul_allow_tf32: false
|
||||||
|
selective_activation_checkpointing: false
|
||||||
|
tp_size: 1
|
||||||
|
|
||||||
|
model:
|
||||||
|
dim: 4096
|
||||||
|
n_layers: 32
|
||||||
|
n_heads: 32
|
||||||
|
rope_theta: 100_000
|
||||||
|
ffn_dim_multiplier: 1.0
|
||||||
|
multiple_of: 256
|
||||||
|
|
||||||
|
data:
|
||||||
|
root_dir: data/shuffled
|
||||||
|
sources:
|
||||||
|
dclm_baseline_1.0: 1.0
|
||||||
|
batch_size: 2
|
||||||
|
prefetch_size: 1024
|
||||||
|
seq_len: 4096
|
||||||
|
n_views: 2
|
||||||
|
load_async: true
|
||||||
|
tokenizer:
|
||||||
|
name: tiktoken
|
||||||
|
path: tokenizers/cl_toplang_128k.tiktoken
|
||||||
|
|
||||||
|
profiling:
|
||||||
|
run: true
|
||||||
|
mem_warmup: 0
|
||||||
|
mem_steps: 4
|
||||||
|
profile_warmup: 100
|
||||||
|
profile_steps: 4
|
||||||
|
|
||||||
|
checkpoint:
|
||||||
|
dump:
|
||||||
|
every: 10000
|
||||||
|
keep: -1
|
||||||
|
eval:
|
||||||
|
every: 1000
|
||||||
|
keep: 3
|
||||||
|
|
||||||
|
logging:
|
||||||
|
freq: 1
|
||||||
|
|
||||||
|
async_eval_gpus: 8
|
||||||
|
eval:
|
||||||
|
dataset_dir: datasets/eval
|
||||||
|
harness:
|
||||||
|
tasks:
|
||||||
|
- hellaswag
|
||||||
|
- task: boolq
|
||||||
|
dataset_kwargs:
|
||||||
|
trust_remote_code: true
|
||||||
|
- piqa
|
||||||
|
- task: social_iqa
|
||||||
|
dataset_kwargs:
|
||||||
|
trust_remote_code: true
|
||||||
|
- winogrande
|
||||||
|
- openbookqa
|
||||||
|
- arc_easy
|
||||||
|
- arc_challenge
|
||||||
|
- race
|
||||||
|
- commonsense_qa
|
||||||
|
# - coqa
|
||||||
|
- copa
|
||||||
|
- mmlu
|
||||||
|
- mmlu_pro
|
||||||
|
# - task: nq_open
|
||||||
|
# num_fewshot: 5
|
||||||
|
# - triviaqa
|
||||||
|
# - gsm8k
|
||||||
|
# - bbh
|
||||||
|
validation:
|
||||||
|
max_steps: 1000
|
||||||
|
generator:
|
||||||
|
max_tokens: 8192
|
||||||
|
dtype: bf16
|
354
apps/main/eval.py
Normal file
354
apps/main/eval.py
Normal file
|
@ -0,0 +1,354 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lingua.args import dump_config
|
||||||
|
from lingua.data import init_choice_state, setup_sources
|
||||||
|
from lm_eval import simple_evaluate
|
||||||
|
from lm_eval.api.instance import Instance
|
||||||
|
from lm_eval.api.model import LM
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
|
||||||
|
from bytelatent.distributed import (
|
||||||
|
DistributedArgs,
|
||||||
|
dist_mean_dict,
|
||||||
|
get_global_rank,
|
||||||
|
get_world_size,
|
||||||
|
setup_torch_distributed,
|
||||||
|
)
|
||||||
|
from bytelatent.transformer import LMTransformer, LMTransformerArgs
|
||||||
|
|
||||||
|
from apps.main.generate import (
|
||||||
|
PackedCausalTransformerGenerator,
|
||||||
|
PackedCausalTransformerGeneratorArgs,
|
||||||
|
load_consolidated_model_and_tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
EVAL_FOLDER_NAME = "{:010d}"
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LMHarnessArgs:
|
||||||
|
tasks: Optional[List[Any]] = None
|
||||||
|
num_fewshot: Optional[int] = None
|
||||||
|
device: Optional[str] = None
|
||||||
|
use_cache: Optional[str] = None
|
||||||
|
cache_requests: bool = False
|
||||||
|
rewrite_requests_cache: bool = False
|
||||||
|
delete_requests_cache: bool = False
|
||||||
|
limit: Optional[Union[int, float]] = None
|
||||||
|
bootstrap_iters: int = 100000
|
||||||
|
check_integrity: bool = False
|
||||||
|
write_out: bool = False
|
||||||
|
log_samples: bool = True
|
||||||
|
system_instruction: Optional[str] = None
|
||||||
|
apply_chat_template: Union[bool, str] = False
|
||||||
|
fewshot_as_multiturn: bool = False
|
||||||
|
gen_kwargs: Optional[str] = None
|
||||||
|
verbosity: str = "INFO"
|
||||||
|
predict_only: bool = False
|
||||||
|
random_seed: int = 0
|
||||||
|
numpy_random_seed: int = 1234
|
||||||
|
torch_random_seed: int = 1234
|
||||||
|
fewshot_random_seed: int = 1234
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationArgs:
|
||||||
|
max_steps: Optional[int] = (
|
||||||
|
None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
|
||||||
|
)
|
||||||
|
use_val_from_train_src: bool = True # Use the validation set from training sources
|
||||||
|
root_dir: str = ""
|
||||||
|
sources: List[str] = field(default_factory=list) # Other sources to eval on
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvalArgs:
|
||||||
|
name: str = "evals"
|
||||||
|
dump_dir: Optional[str] = None
|
||||||
|
metric_log_dir: Optional[str] = None
|
||||||
|
ckpt_dir: str = ""
|
||||||
|
generator: PackedCausalTransformerGeneratorArgs = field(
|
||||||
|
default_factory=PackedCausalTransformerGeneratorArgs
|
||||||
|
)
|
||||||
|
harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
|
||||||
|
validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)
|
||||||
|
|
||||||
|
wandb: Optional[Any] = None
|
||||||
|
|
||||||
|
global_step: Optional[int] = None # for in-training evaluation
|
||||||
|
|
||||||
|
|
||||||
|
def all_dicts_same(dict_list):
|
||||||
|
if not dict_list: # Check if the list is empty
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Compare each dictionary to the first one
|
||||||
|
first_dict = dict_list[0]
|
||||||
|
return all(d == first_dict for d in dict_list)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAccelerator:
|
||||||
|
def gather(self, tensor):
|
||||||
|
l = [torch.zeros_like(tensor) for _ in range(get_world_size())]
|
||||||
|
torch.distributed.all_gather(l, tensor)
|
||||||
|
return torch.stack(l)
|
||||||
|
|
||||||
|
def wait_for_everyone(self):
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
# Light wrapper around generator for lm-eval harness
|
||||||
|
class EvalHarnessLM(LM):
|
||||||
|
def __init__(self, generator):
|
||||||
|
super().__init__()
|
||||||
|
self.generator = generator
|
||||||
|
self.accelerator = MockAccelerator()
|
||||||
|
self._rank = get_global_rank()
|
||||||
|
self._world_size = get_world_size()
|
||||||
|
self.device = generator.device
|
||||||
|
|
||||||
|
def generate_until(self, requests: List[Instance]) -> List[str]:
|
||||||
|
prompts, gen_args = zip(*[req.args for req in requests])
|
||||||
|
assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
|
||||||
|
gen_args = gen_args[0]
|
||||||
|
temperature = gen_args.get("temperature", 0.0)
|
||||||
|
top_p = gen_args.get("top_p", None)
|
||||||
|
top_k = gen_args.get("top_k", None)
|
||||||
|
until = gen_args.get("until", [])
|
||||||
|
|
||||||
|
self.generator.temperature = temperature
|
||||||
|
self.generator.top_p = top_p
|
||||||
|
self.generator.top_k = top_k
|
||||||
|
self.generator.until = until
|
||||||
|
generations, _, _ = self.generator.generate(prompts)
|
||||||
|
filtered_gen = []
|
||||||
|
for g in generations:
|
||||||
|
for e in until:
|
||||||
|
g = g.replace(e, "")
|
||||||
|
filtered_gen.append(g)
|
||||||
|
return filtered_gen
|
||||||
|
|
||||||
|
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
|
||||||
|
prompts, continuations = zip(*[req.args for req in requests])
|
||||||
|
inputs = [req.args[0] + req.args[1] for req in requests]
|
||||||
|
max_gen_len = self.generator.max_gen_len
|
||||||
|
# We temporarily lower max gen len
|
||||||
|
self.generator.max_gen_len = 1
|
||||||
|
_, lls, greedy = self.generator.generate(inputs)
|
||||||
|
results = []
|
||||||
|
for p, ll, gr in zip(prompts, lls, greedy):
|
||||||
|
p_len = len(
|
||||||
|
self.generator.tokenizer.encode(p, add_bos=False, add_eos=False)
|
||||||
|
)
|
||||||
|
results.append((ll[p_len:].sum().item(), gr[p_len:].all().item()))
|
||||||
|
|
||||||
|
self.generator.max_gen_len = max_gen_len
|
||||||
|
return results
|
||||||
|
|
||||||
|
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
|
||||||
|
prompts = [req.args[0] for req in requests]
|
||||||
|
max_gen_len = self.generator.max_gen_len
|
||||||
|
# We temporarily lower max gen len
|
||||||
|
self.generator.max_gen_len = 1
|
||||||
|
_, lls, _ = self.generator.generate(prompts)
|
||||||
|
results = []
|
||||||
|
for ll in lls:
|
||||||
|
results.append((ll.sum().item(),))
|
||||||
|
self.generator.max_gen_len = max_gen_len
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
|
||||||
|
srcs = {}
|
||||||
|
for src in val_args.sources:
|
||||||
|
path = os.path.join(val_args.root_dir, src)
|
||||||
|
srcs[path] = 1.0
|
||||||
|
for src in train_cfg.data.sources:
|
||||||
|
path = os.path.join(train_cfg.data.root_dir, src)
|
||||||
|
srcs[path] = 1.0
|
||||||
|
|
||||||
|
multi_state = init_choice_state(
|
||||||
|
"", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
|
||||||
|
)
|
||||||
|
path_to_iter = setup_sources(multi_state)
|
||||||
|
|
||||||
|
max_gen_len = generator.max_gen_len
|
||||||
|
# We temporarily lower max gen len
|
||||||
|
generator.max_gen_len = 1
|
||||||
|
|
||||||
|
all_val_metrics = {}
|
||||||
|
for src in path_to_iter:
|
||||||
|
jsonl_iterator = path_to_iter[src]
|
||||||
|
texts = []
|
||||||
|
logger.info(f"Running validation on {src}...")
|
||||||
|
for step, (content, state) in enumerate(jsonl_iterator):
|
||||||
|
if state["current_iter"] > 0 or (
|
||||||
|
val_args.max_steps is not None and step >= val_args.max_steps
|
||||||
|
):
|
||||||
|
break
|
||||||
|
content_key = "text" if ("text" in content) else "content"
|
||||||
|
texts.append(content[content_key])
|
||||||
|
|
||||||
|
_, loglikelihood, _ = generator.generate(texts)
|
||||||
|
|
||||||
|
metrics = defaultdict(list)
|
||||||
|
for i, ll in enumerate(loglikelihood):
|
||||||
|
tmp = ll.sum().item()
|
||||||
|
metrics["nll"].append(tmp)
|
||||||
|
metrics["nll_per_token"].append(tmp / len(ll))
|
||||||
|
metrics["nll_per_char"].append(tmp / len(texts[i]))
|
||||||
|
|
||||||
|
metrics["avg_seqlen"].append(len(ll))
|
||||||
|
|
||||||
|
for m in metrics:
|
||||||
|
metrics[m] = sum(metrics[m]) / len(metrics[m])
|
||||||
|
metrics.update(dist_mean_dict(metrics))
|
||||||
|
logger.info(f"Validation on {src} done. Metrics: {metrics}")
|
||||||
|
|
||||||
|
name = os.path.basename(src)
|
||||||
|
if name in all_val_metrics:
|
||||||
|
logger.warning(
|
||||||
|
f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
|
||||||
|
)
|
||||||
|
name = f"{name}_1"
|
||||||
|
all_val_metrics[name] = metrics
|
||||||
|
|
||||||
|
generator.max_gen_len = max_gen_len
|
||||||
|
|
||||||
|
return all_val_metrics
|
||||||
|
|
||||||
|
|
||||||
|
def launch_eval(cfg: EvalArgs):
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
setup_torch_distributed(DistributedArgs())
|
||||||
|
if (
|
||||||
|
Path(cfg.ckpt_dir).exists()
|
||||||
|
and (Path(cfg.ckpt_dir) / "params.json").exists()
|
||||||
|
and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
|
||||||
|
):
|
||||||
|
consolidate_path = Path(cfg.ckpt_dir)
|
||||||
|
else:
|
||||||
|
consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
|
||||||
|
if not consolidate_path.exists() and get_global_rank() == 0:
|
||||||
|
consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)
|
||||||
|
|
||||||
|
Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)
|
||||||
|
|
||||||
|
consolidate_path = str(consolidate_path)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
logger.info("Loading model")
|
||||||
|
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
|
||||||
|
consolidate_path,
|
||||||
|
model_cls=LMTransformer,
|
||||||
|
model_args_cls=LMTransformerArgs,
|
||||||
|
)
|
||||||
|
logger.info("Model loaded")
|
||||||
|
model.eval()
|
||||||
|
generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)
|
||||||
|
|
||||||
|
wrap = EvalHarnessLM(generator)
|
||||||
|
results = simple_evaluate(wrap, **asdict(cfg.harness))
|
||||||
|
val_results = None
|
||||||
|
if cfg.validation:
|
||||||
|
val_results = eval_on_val(generator, cfg.validation, train_cfg)
|
||||||
|
if get_global_rank() == 0:
|
||||||
|
with open(Path(cfg.dump_dir) / "results.json", "w") as f:
|
||||||
|
f.write(json.dumps(results))
|
||||||
|
logger.info(f"All evaluation results: {results['results']}")
|
||||||
|
if val_results is not None:
|
||||||
|
with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
|
||||||
|
f.write(json.dumps(val_results))
|
||||||
|
logger.info(f"All validation results: {val_results}")
|
||||||
|
if cfg.metric_log_dir and get_global_rank() == 0:
|
||||||
|
metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"
|
||||||
|
|
||||||
|
logger.info(f"Writing metric logs to {metric_log_path}")
|
||||||
|
timestamp = {
|
||||||
|
"created_at": datetime.utcnow().isoformat(),
|
||||||
|
}
|
||||||
|
if cfg.global_step is not None:
|
||||||
|
timestamp["global_step"] = cfg.global_step
|
||||||
|
print(
|
||||||
|
json.dumps(timestamp | results["results"]),
|
||||||
|
file=open(metric_log_path, mode="a"),
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
|
||||||
|
if val_results is not None:
|
||||||
|
print(
|
||||||
|
json.dumps(timestamp | val_results),
|
||||||
|
file=open(val_log_path, mode="a"),
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
del generator
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
|
||||||
|
This accepts arguments as a dot list
|
||||||
|
So if the dataclass looks like
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DummyArgs:
|
||||||
|
name: str
|
||||||
|
model: LMTransformerArgsgs
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LMTransformerArgsgs:
|
||||||
|
dim: int
|
||||||
|
|
||||||
|
Then you can pass model.dim=32 to change values in LMTransformerArgsgs
|
||||||
|
or just name=tictac for top level attributes.
|
||||||
|
|
||||||
|
The behavior here is as follows:
|
||||||
|
1. We instantiate EvalArgs with its default values
|
||||||
|
2. We override those default values with the ones in the provided config file
|
||||||
|
3. We override the result with the additional arguments provided through command line
|
||||||
|
|
||||||
|
For example, if the config is the following
|
||||||
|
|
||||||
|
model:
|
||||||
|
dim: 128
|
||||||
|
n_layers: 4
|
||||||
|
|
||||||
|
and you call eval.py with eval.py model.dim=64
|
||||||
|
|
||||||
|
Then the final TrainArgs will have
|
||||||
|
|
||||||
|
model:
|
||||||
|
dim: 64
|
||||||
|
n_layers: 4
|
||||||
|
|
||||||
|
Plus all the default values in EvalArgs dataclass.
|
||||||
|
"""
|
||||||
|
cli_args = OmegaConf.from_cli()
|
||||||
|
file_cfg = OmegaConf.load(cli_args.config)
|
||||||
|
# We remove 'config' attribute from config as the underlying DataClass does not have it
|
||||||
|
del cli_args.config
|
||||||
|
|
||||||
|
default_cfg = OmegaConf.structured(EvalArgs())
|
||||||
|
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
|
||||||
|
cfg = OmegaConf.to_object(cfg)
|
||||||
|
launch_eval(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
463
apps/main/generate.py
Normal file
463
apps/main/generate.py
Normal file
|
@ -0,0 +1,463 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lingua.args import dataclass_from_dict
|
||||||
|
from lingua.tokenizers.abstract_tokenizer import Tokenizer
|
||||||
|
from lingua.tokenizers.build_tokenizer import build_tokenizer
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.nn.attention.flex_attention import create_block_mask
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from bytelatent.base_transformer import (
|
||||||
|
Attention,
|
||||||
|
causal_mask,
|
||||||
|
generate_doc_mask_mod,
|
||||||
|
lengths_to_local_ids,
|
||||||
|
lengths_to_start_ids,
|
||||||
|
)
|
||||||
|
from bytelatent.checkpoint import CONSOLIDATE_NAME
|
||||||
|
from bytelatent.transformer import LMTransformer, LMTransformerArgs
|
||||||
|
|
||||||
|
|
||||||
|
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
||||||
|
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||||
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
mask = probs_sum - probs_sort > p
|
||||||
|
probs_sort[mask] = 0.0
|
||||||
|
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||||
|
next_token = torch.gather(probs_idx, -1, next_token)
|
||||||
|
return next_token
|
||||||
|
|
||||||
|
|
||||||
|
def sample_top_k(probs, k):
|
||||||
|
topk_value, _ = torch.topk(probs, k) # batch_sz x topk
|
||||||
|
min_value_top_k = topk_value[:, [-1]]
|
||||||
|
probs[probs < min_value_top_k] = 0.0
|
||||||
|
probs.div_(probs.sum(dim=-1, keepdim=True))
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
|
return next_token
|
||||||
|
|
||||||
|
|
||||||
|
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
|
||||||
|
shape = logits.shape
|
||||||
|
logits = logits.flatten(end_dim=-2)
|
||||||
|
if temperature > 0.0:
|
||||||
|
probs = torch.softmax(logits / temperature, dim=-1)
|
||||||
|
|
||||||
|
if top_p is not None:
|
||||||
|
next_token = sample_top_p(probs, top_p)
|
||||||
|
elif top_k is not None:
|
||||||
|
next_token = sample_top_k(probs, top_k)
|
||||||
|
else:
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
|
else:
|
||||||
|
next_token = torch.argmax(logits, dim=-1)
|
||||||
|
return next_token.view(shape[:-1])
|
||||||
|
|
||||||
|
|
||||||
|
def pack_prompts(prompts: List[int]):
|
||||||
|
res = []
|
||||||
|
lengths = []
|
||||||
|
for i, p in enumerate(prompts):
|
||||||
|
p = torch.tensor(p, dtype=torch.long)
|
||||||
|
l = p.size(0)
|
||||||
|
res.append(p)
|
||||||
|
lengths.append(l)
|
||||||
|
lengths = torch.tensor(lengths, dtype=torch.long)
|
||||||
|
res = torch.cat(res)
|
||||||
|
return res, lengths
|
||||||
|
|
||||||
|
|
||||||
|
def batch_prompts(prompts, max_elements, lengths=None):
|
||||||
|
batches = []
|
||||||
|
current_batch = []
|
||||||
|
current_count = 0
|
||||||
|
|
||||||
|
for i in range(len(prompts)):
|
||||||
|
prt = prompts[i]
|
||||||
|
prompt_size = len(prt) if lengths is None else lengths[i]
|
||||||
|
if current_count + prompt_size <= max_elements:
|
||||||
|
current_batch.append(prt)
|
||||||
|
current_count += prompt_size
|
||||||
|
else:
|
||||||
|
if current_batch: # Add the current batch to batches
|
||||||
|
batches.append(current_batch)
|
||||||
|
# Start a new batch with the current prompt
|
||||||
|
current_batch = [prt]
|
||||||
|
current_count = prompt_size
|
||||||
|
|
||||||
|
# Add the last batch if it contains any prompts
|
||||||
|
if current_batch:
|
||||||
|
batches.append(current_batch)
|
||||||
|
|
||||||
|
return batches
|
||||||
|
|
||||||
|
|
||||||
|
class KVCache(nn.Module):
|
||||||
|
def __init__(self, bsz, seqlen, n_heads, head_dim, dtype, device):
|
||||||
|
super().__init__()
|
||||||
|
shape = (bsz, seqlen, n_heads, head_dim)
|
||||||
|
self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype, device=device))
|
||||||
|
self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype, device=device))
|
||||||
|
self.offset = 0
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.k_cache.zero_()
|
||||||
|
self.v_cache.zero_()
|
||||||
|
self.offset = 0
|
||||||
|
|
||||||
|
def update(self, k_val, v_val, tok_idx):
|
||||||
|
# input_pos: [B], k_val: [B, S, H, D]
|
||||||
|
self.k_cache.index_copy_(1, self.offset + tok_idx, k_val)
|
||||||
|
self.v_cache.index_copy_(1, self.offset + tok_idx, v_val)
|
||||||
|
return self.k_cache, self.v_cache
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PackedCausalTransformerGeneratorArgs:
|
||||||
|
temperature: float = 0.0
|
||||||
|
top_p: Optional[float] = None
|
||||||
|
top_k: Optional[float] = None
|
||||||
|
max_gen_len: int = 512 # Maximum number of tokens to generate
|
||||||
|
max_tokens: int = 1024 # Maximum number of tokens that can go through the model
|
||||||
|
max_prompt_len: Optional[int] = None
|
||||||
|
until: List[str] = field(default_factory=list)
|
||||||
|
compile_prefilling: bool = False
|
||||||
|
reduce_generation_overhead: bool = False
|
||||||
|
show_progress: bool = False
|
||||||
|
dtype: Optional[str] = "bf16"
|
||||||
|
device: Optional[str] = "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
class PackedCausalTransformerGenerator:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: PackedCausalTransformerGeneratorArgs,
|
||||||
|
model: nn.Module,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This class wraps a causal transformer model with its corresponding tokenizer
|
||||||
|
and provides an efficient way to pack prompts together and do generation on
|
||||||
|
the packed sequence.
|
||||||
|
|
||||||
|
For example, if we had the prompts "Hello, I am a " and "Initiating calibration "
|
||||||
|
Then this class will concatenate those sequence (pack them together)
|
||||||
|
"Hello, I am a Initiating calibration"
|
||||||
|
And make the necessary attention masks such that a sequence only attends to itself
|
||||||
|
during prefilling and generation.
|
||||||
|
|
||||||
|
This class creates a fixed size cache of size max_tokens or sum of prompt sizes
|
||||||
|
+ the max number of generated tokens per sequence.
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.temperature = cfg.temperature
|
||||||
|
self.top_p = cfg.top_p
|
||||||
|
self.top_k = cfg.top_k
|
||||||
|
|
||||||
|
self.max_gen_len = cfg.max_gen_len
|
||||||
|
self.max_tokens = cfg.max_tokens
|
||||||
|
self.max_prompt_len = cfg.max_prompt_len
|
||||||
|
self.until = cfg.until
|
||||||
|
self.max_until_size = max([len(e) for e in self.until]) if self.until else 1
|
||||||
|
self.device = cfg.device
|
||||||
|
|
||||||
|
# Compile if necessary
|
||||||
|
self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
|
||||||
|
self.generate_next_token = torch.compile(
|
||||||
|
self.generate_next_token,
|
||||||
|
mode="reduce-overhead",
|
||||||
|
disable=not cfg.reduce_generation_overhead,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.show_progress = cfg.show_progress
|
||||||
|
self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]
|
||||||
|
|
||||||
|
self.prefill_doc_id, self.prefill_tok_id = None, None
|
||||||
|
self.padded_doc_id, self.padded_tok_id = None, None
|
||||||
|
self.current_doc_id, self.current_tok_id = None, None
|
||||||
|
self.padded_doc_start = None
|
||||||
|
self.prefill_mask = None
|
||||||
|
|
||||||
|
def clear_cache(self, offset):
|
||||||
|
for module in self.model.modules():
|
||||||
|
if isinstance(module, Attention):
|
||||||
|
if not hasattr(module, "kv_cache"):
|
||||||
|
module.kv_cache = KVCache(
|
||||||
|
1,
|
||||||
|
self.max_tokens,
|
||||||
|
module.n_kv_heads,
|
||||||
|
module.head_dim,
|
||||||
|
self.dtype,
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
module.kv_cache.offset = offset
|
||||||
|
|
||||||
|
@torch.compiler.disable
|
||||||
|
def setup_prefilling(self, lengths: torch.Tensor):
|
||||||
|
# The KV cache is a fixed size tensor of size max_tokens that we need
|
||||||
|
# to update in order to do correct autoregressive generation.
|
||||||
|
|
||||||
|
# Here we will generate token by token but on multiple sequences
|
||||||
|
# at once. To do so, we need to have an attention mask that makes
|
||||||
|
# each sequence independent.
|
||||||
|
|
||||||
|
# Each sequence will write to its allocated space in the KV Cache.
|
||||||
|
# We allocate len(seq) + max_gen_len to each sequence in the cache.
|
||||||
|
|
||||||
|
# We will generate max_gen_len for each document
|
||||||
|
padded_lengths = lengths + self.max_gen_len
|
||||||
|
max_tokens = self.max_tokens or padded_lengths.sum().item()
|
||||||
|
# The last document might have more padding to fill up to max_tokens
|
||||||
|
padded_lengths[-1] += max_tokens - padded_lengths.sum()
|
||||||
|
|
||||||
|
# This is the start index in the cache for each document
|
||||||
|
self.padded_doc_start = lengths_to_start_ids(padded_lengths)
|
||||||
|
# For example with ab--123--cdef--
|
||||||
|
# this would be 0, 4, 9 if max_gen_len is 2
|
||||||
|
|
||||||
|
# We repeat interleave to align with tokens for prefilling
|
||||||
|
# Ex: ab--123--cdef--
|
||||||
|
# 000044444999999
|
||||||
|
prefill_offset = torch.repeat_interleave(self.padded_doc_start, lengths)
|
||||||
|
# This offset will make sure the tokens are written to the
|
||||||
|
# correct positions in the cache during prefilling
|
||||||
|
|
||||||
|
# We either init the cache or clear it by resetting the offset to prefill_offset
|
||||||
|
self.clear_cache(prefill_offset)
|
||||||
|
|
||||||
|
# The prefilling mask looks like the following for
|
||||||
|
# the two packed sequences ab and 123 : ab123
|
||||||
|
# Where spaces are empty cache positions
|
||||||
|
# keys
|
||||||
|
# ab---123---
|
||||||
|
# queries a 10000000000
|
||||||
|
# b 11000000000
|
||||||
|
# 1 00000100000
|
||||||
|
# 2 00000110000
|
||||||
|
# 3 00000111000
|
||||||
|
# We make sure to skip the empty cache positions
|
||||||
|
# and only attend to positions within the same sequence
|
||||||
|
doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths, padded_lengths)
|
||||||
|
self.prefill_mask = create_block_mask(
|
||||||
|
doc_mask_mod, 1, None, lengths.sum(), max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# This creates the prefilling token ids which look like
|
||||||
|
# the following for the packed sequence abcdefg1234
|
||||||
|
# abcdefg1234
|
||||||
|
# 01234560123
|
||||||
|
# The token id gives us the position within each sequence
|
||||||
|
# This is used to compute ROPE and to update the cache
|
||||||
|
# At each forward pass the current tokens are written to
|
||||||
|
# offset + tok_id
|
||||||
|
self.prefill_doc_id, self.prefill_tok_id = lengths_to_local_ids(lengths)
|
||||||
|
|
||||||
|
# This creates the padded token and document ids
|
||||||
|
# which look like the following for the packed sequence ab123
|
||||||
|
# ab---123--- ab---123---
|
||||||
|
# padded_doc_id 00000111111 padded_tok_id 01234012345
|
||||||
|
# This will later be useful for the attention mask at generation
|
||||||
|
self.padded_doc_id, self.padded_tok_id = lengths_to_local_ids(padded_lengths)
|
||||||
|
|
||||||
|
@torch.compiler.disable
|
||||||
|
def setup_generation(self, lengths):
|
||||||
|
# KV Cache offset is set to the start of the padded documents
|
||||||
|
for module in self.model.modules():
|
||||||
|
if isinstance(module, Attention):
|
||||||
|
module.kv_cache.offset = self.padded_doc_start
|
||||||
|
# The token ids during generations correspond to the lengths of each doc
|
||||||
|
# current_tok_id will be incremented during generation
|
||||||
|
self.current_tok_id = lengths.clone()
|
||||||
|
# Since we're generating one token per document
|
||||||
|
# the document id is just an arange
|
||||||
|
self.current_doc_id = torch.arange(lengths.size(0), device=lengths.device)
|
||||||
|
|
||||||
|
# From here on some methods for generation
|
||||||
|
def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
|
||||||
|
# Prefilling is done by taking multiple packed sequences and
|
||||||
|
# doing block diagonal attention on them so they remain independent
|
||||||
|
self.setup_prefilling(lengths=lengths)
|
||||||
|
prefill_out = self.model.forward(
|
||||||
|
tokens,
|
||||||
|
tok_idx=self.prefill_tok_id,
|
||||||
|
mask=self.prefill_mask,
|
||||||
|
attn_impl="flex_attention",
|
||||||
|
)
|
||||||
|
self.setup_generation(lengths=lengths)
|
||||||
|
return prefill_out
|
||||||
|
|
||||||
|
def generate_next_token(self, current_token):
|
||||||
|
# Since we're doing generation with multiple sequences at once
|
||||||
|
# we need to ignore tokens and cache entries from other sequences
|
||||||
|
# or in the future.
|
||||||
|
# Example mask :
|
||||||
|
# keys
|
||||||
|
# abc--1234--
|
||||||
|
# queries c 11100000000
|
||||||
|
# 4 00000111100
|
||||||
|
|
||||||
|
# mask shape : (n_seqs, cache_size)
|
||||||
|
doc_mask = self.current_doc_id.unsqueeze(1) == self.padded_doc_id.unsqueeze(0)
|
||||||
|
caus_mask = self.current_tok_id.unsqueeze(1) >= self.padded_tok_id.unsqueeze(0)
|
||||||
|
mask = doc_mask & caus_mask
|
||||||
|
out = self.model.forward(
|
||||||
|
current_token,
|
||||||
|
tok_idx=self.current_tok_id, # n_seqs
|
||||||
|
mask=mask,
|
||||||
|
attn_impl="sdpa",
|
||||||
|
)
|
||||||
|
self.current_tok_id += 1
|
||||||
|
return out
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def generate(self, prompts):
|
||||||
|
# Tokenize
|
||||||
|
prompts = [
|
||||||
|
self.tokenizer.encode(p, add_bos=True, add_eos=False) for p in prompts
|
||||||
|
]
|
||||||
|
# Truncate
|
||||||
|
max_seqlen = (
|
||||||
|
self.max_tokens
|
||||||
|
if not hasattr(self.model, "max_seqlen")
|
||||||
|
else self.model.max_seqlen
|
||||||
|
)
|
||||||
|
max_prompt_len = self.max_prompt_len or min(
|
||||||
|
max_seqlen - self.max_gen_len, self.max_tokens - self.max_gen_len
|
||||||
|
)
|
||||||
|
prompts = [p[-max_prompt_len:] for p in prompts]
|
||||||
|
# Account for the generation in lengths
|
||||||
|
padded_lengths = [len(p) + self.max_gen_len for p in prompts]
|
||||||
|
generation = []
|
||||||
|
loglikelihood = []
|
||||||
|
greedy = []
|
||||||
|
it = batch_prompts(prompts, self.max_tokens, lengths=padded_lengths)
|
||||||
|
if self.show_progress:
|
||||||
|
it = tqdm(it)
|
||||||
|
for batch in it:
|
||||||
|
n_seqs = len(batch)
|
||||||
|
generated_tokens = [[] for _ in range(n_seqs)]
|
||||||
|
is_done = [False for _ in range(n_seqs)]
|
||||||
|
packed_batch, lengths = pack_prompts(batch)
|
||||||
|
packed_batch, lengths = packed_batch.cuda(), lengths.cuda()
|
||||||
|
n_seqs = lengths.size(0)
|
||||||
|
|
||||||
|
# Prefilling cache
|
||||||
|
prompt_logits = self.prefill(packed_batch.unsqueeze(0), lengths)
|
||||||
|
# Selecting last token in each prompt
|
||||||
|
all_tokens = sample_tokens(
|
||||||
|
prompt_logits, self.temperature, self.top_p, self.top_k
|
||||||
|
)
|
||||||
|
start_token = all_tokens[:, lengths.cumsum(0) - 1]
|
||||||
|
|
||||||
|
for seq_id, tok in enumerate(start_token.squeeze(0).tolist()):
|
||||||
|
generated_tokens[seq_id].append(tok)
|
||||||
|
|
||||||
|
current_token = start_token
|
||||||
|
for i in range(1, self.max_gen_len):
|
||||||
|
|
||||||
|
next_logits = self.generate_next_token(current_token)
|
||||||
|
next_token = sample_tokens(
|
||||||
|
next_logits.clone(), self.temperature, self.top_p, self.top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
for seq_id, tok in enumerate(next_token.squeeze(0).tolist()):
|
||||||
|
if not is_done[seq_id]:
|
||||||
|
generated_tokens[seq_id].append(tok)
|
||||||
|
current_end_str = self.tokenizer.decode(
|
||||||
|
generated_tokens[seq_id][-self.max_until_size :]
|
||||||
|
)
|
||||||
|
contains_end_string = any(
|
||||||
|
[e in current_end_str for e in self.until]
|
||||||
|
)
|
||||||
|
is_done[seq_id] = (
|
||||||
|
contains_end_string or tok == self.tokenizer.eos_id
|
||||||
|
)
|
||||||
|
if all(is_done):
|
||||||
|
break
|
||||||
|
|
||||||
|
current_token = next_token
|
||||||
|
|
||||||
|
generation.extend([self.tokenizer.decode(g) for g in generated_tokens])
|
||||||
|
|
||||||
|
for p, logit in zip(
|
||||||
|
batch, prompt_logits.squeeze(0).split(lengths.tolist())
|
||||||
|
):
|
||||||
|
x = logit[:-1]
|
||||||
|
y = torch.tensor(p[1:], device=x.device)
|
||||||
|
loglikelihood.append(-F.cross_entropy(x, y, reduction="none").cpu())
|
||||||
|
greedy.append((x.argmax(dim=-1) == y).cpu())
|
||||||
|
|
||||||
|
return generation, loglikelihood, greedy
|
||||||
|
|
||||||
|
|
||||||
|
def load_consolidated_model_and_tokenizer(
|
||||||
|
consolidated_path,
|
||||||
|
model_cls=LMTransformer,
|
||||||
|
model_args_cls=LMTransformerArgs,
|
||||||
|
):
|
||||||
|
ckpt_path = Path(consolidated_path)
|
||||||
|
config = ckpt_path / "params.json"
|
||||||
|
config = OmegaConf.load(config)
|
||||||
|
|
||||||
|
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
|
||||||
|
config.distributed.model_dtype
|
||||||
|
]
|
||||||
|
model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
|
||||||
|
tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
|
||||||
|
model = model_cls(model_args)
|
||||||
|
st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
|
||||||
|
model.load_state_dict(st_dict["model"])
|
||||||
|
model = model.cuda().eval()
|
||||||
|
for param in model.parameters():
|
||||||
|
param.data = param.data.to(dtype=param_dtype)
|
||||||
|
return model, tokenizer, config
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Load CLI arguments (overrides) and combine with a YAML config
|
||||||
|
cfg = OmegaConf.from_cli()
|
||||||
|
gen_cfg = dataclass_from_dict(
|
||||||
|
PackedCausalTransformerGeneratorArgs, cfg, strict=False
|
||||||
|
)
|
||||||
|
print(cfg)
|
||||||
|
|
||||||
|
model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)
|
||||||
|
|
||||||
|
generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
|
||||||
|
|
||||||
|
# Allow multiple prompts
|
||||||
|
prompts = []
|
||||||
|
while True:
|
||||||
|
prompt = input("Enter a prompt (or press enter to finish): ")
|
||||||
|
if not prompt:
|
||||||
|
break
|
||||||
|
prompts.append(prompt)
|
||||||
|
|
||||||
|
# Start generation
|
||||||
|
start_time = time.time()
|
||||||
|
generation, loglikelihood, greedy = generator.generate(prompts)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
# Calculate tokens per second
|
||||||
|
total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
|
||||||
|
tokens_per_second = total_tokens / (end_time - start_time)
|
||||||
|
|
||||||
|
# Display the results
|
||||||
|
for i, gen in enumerate(generation):
|
||||||
|
print(f"\nPrompt {i+1}: {prompts[i]}")
|
||||||
|
print(f"Generated Text: {gen}")
|
||||||
|
|
||||||
|
print(f"\nTokens per second: {tokens_per_second:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
654
apps/main/lingua_train.py
Normal file
654
apps/main/lingua_train.py
Normal file
|
@ -0,0 +1,654 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
import wandb
|
||||||
|
import xformers.profiler
|
||||||
|
from lingua.args import dataclass_from_dict, dump_config, flatten_dict
|
||||||
|
from lingua.data import (
|
||||||
|
DataArgs,
|
||||||
|
PackTokensState,
|
||||||
|
build_dataloader_from_args,
|
||||||
|
init_dataloader_state_from_args,
|
||||||
|
)
|
||||||
|
from lingua.tokenizers.build_tokenizer import TokenizerArgs
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from torch.distributed._tensor import DTensor
|
||||||
|
from torch.distributed.checkpoint.stateful import Stateful
|
||||||
|
from torch.optim import lr_scheduler
|
||||||
|
|
||||||
|
from bytelatent.checkpoint import (
|
||||||
|
CheckpointArgs,
|
||||||
|
CheckpointManager,
|
||||||
|
load_from_checkpoint,
|
||||||
|
)
|
||||||
|
from bytelatent.distributed import (
|
||||||
|
DistributedArgs,
|
||||||
|
EnvironmentArgs,
|
||||||
|
check_model_value_range,
|
||||||
|
clean_env,
|
||||||
|
dist_mean_dict,
|
||||||
|
get_device_mesh,
|
||||||
|
get_is_master,
|
||||||
|
get_world_size,
|
||||||
|
init_signal_handler,
|
||||||
|
parallelize_model,
|
||||||
|
requeue_slurm_job,
|
||||||
|
setup_env,
|
||||||
|
setup_torch_distributed,
|
||||||
|
)
|
||||||
|
from bytelatent.logger import init_logger
|
||||||
|
from bytelatent.metrics import (
|
||||||
|
GPUMemoryMonitor,
|
||||||
|
LoggingArgs,
|
||||||
|
MetricLogger,
|
||||||
|
get_num_params,
|
||||||
|
)
|
||||||
|
from bytelatent.optim import OptimArgs, build_optimizer
|
||||||
|
from bytelatent.probe import AutoProbeD
|
||||||
|
from bytelatent.profiling import ProfilerArgs, maybe_run_profiler
|
||||||
|
from bytelatent.stool import StoolArgs, launch_job
|
||||||
|
from bytelatent.transformer import (
|
||||||
|
LMTransformer,
|
||||||
|
LMTransformerArgs,
|
||||||
|
build_fsdp_grouping_plan,
|
||||||
|
get_no_recompute_ops,
|
||||||
|
get_num_flop_per_token,
|
||||||
|
tp_parallelize,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class TrainArgs(BaseModel):
|
||||||
|
name: str = "lingua"
|
||||||
|
dump_dir: str = ""
|
||||||
|
|
||||||
|
seed: int = 42
|
||||||
|
|
||||||
|
# Number of gradient accumulation steps
|
||||||
|
# Total batch size is batch_size*grad_acc_steps
|
||||||
|
grad_acc_steps: int = 1
|
||||||
|
|
||||||
|
gc_collect_freq: int = 1000
|
||||||
|
probe_freq: int | None = None
|
||||||
|
|
||||||
|
# Nb optimizer steps to take
|
||||||
|
steps: int = 1000
|
||||||
|
|
||||||
|
data: DataArgs
|
||||||
|
optim: OptimArgs
|
||||||
|
model: LMTransformerArgs
|
||||||
|
distributed: DistributedArgs
|
||||||
|
env: EnvironmentArgs
|
||||||
|
|
||||||
|
checkpoint: CheckpointArgs
|
||||||
|
profiling: ProfilerArgs
|
||||||
|
logging: LoggingArgs
|
||||||
|
|
||||||
|
# If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
|
||||||
|
async_eval_gpus: int | None = None
|
||||||
|
eval: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainState(Stateful):
|
||||||
|
step: int # Nb of steps taken by the optimizer
|
||||||
|
acc_step: int # Nb of accumulation steps done since last optimizer step
|
||||||
|
scheduler: lr_scheduler.LambdaLR
|
||||||
|
data_loader_state: PackTokensState
|
||||||
|
|
||||||
|
def state_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"step": self.step,
|
||||||
|
"acc_step": self.acc_step,
|
||||||
|
"data_loader_state": self.data_loader_state,
|
||||||
|
"scheduler": self.scheduler.state_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
self.step = state_dict["step"]
|
||||||
|
self.acc_step = state_dict["acc_step"]
|
||||||
|
self.data_loader_state = PackTokensState(**state_dict["data_loader_state"])
|
||||||
|
self.scheduler.load_state_dict(state_dict["scheduler"])
|
||||||
|
|
||||||
|
|
||||||
|
def validate_train_args(args: TrainArgs, output_size: int):
|
||||||
|
if args.model.vocab_size < 0:
|
||||||
|
logger.info(f"Setting model output size to {args.model.vocab_size}")
|
||||||
|
args.model.vocab_size = output_size
|
||||||
|
assert (
|
||||||
|
args.model.vocab_size == output_size
|
||||||
|
), "Vocab size should be the same as output size"
|
||||||
|
|
||||||
|
assert args.dump_dir, "Dump dir not set"
|
||||||
|
|
||||||
|
if args.checkpoint.path is None:
|
||||||
|
logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
|
||||||
|
args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
|
||||||
|
|
||||||
|
for source in args.data.sources:
|
||||||
|
data_path = os.path.join(args.data.root_dir, source)
|
||||||
|
assert os.path.exists(data_path), f"{data_path} doesn't exist"
|
||||||
|
|
||||||
|
if (
|
||||||
|
args.distributed.dp_replicate
|
||||||
|
* args.distributed.dp_shard
|
||||||
|
* args.distributed.tp_size
|
||||||
|
!= get_world_size()
|
||||||
|
):
|
||||||
|
assert get_world_size() % args.distributed.dp_shard == 0
|
||||||
|
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
|
||||||
|
|
||||||
|
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
|
||||||
|
args.distributed.dp_replicate = (
|
||||||
|
args.distributed.dp_replicate // args.distributed.tp_size
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
args.distributed.dp_replicate
|
||||||
|
* args.distributed.dp_shard
|
||||||
|
* args.distributed.tp_size
|
||||||
|
== get_world_size()
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.distributed.fsdp_type == "no_shard":
|
||||||
|
assert (
|
||||||
|
args.distributed.dp_shard == 1
|
||||||
|
and args.distributed.dp_replicate == get_world_size()
|
||||||
|
)
|
||||||
|
|
||||||
|
args.model.max_seqlen = args.data.seq_len
|
||||||
|
|
||||||
|
if args.distributed.tp_size == 1:
|
||||||
|
logger.warning(
|
||||||
|
"Tensor parallelism has not been tested for a while, use at your own risk"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
args.probe_freq != args.profiling.mem_steps
|
||||||
|
), "Don't profile during probe step"
|
||||||
|
assert (
|
||||||
|
args.probe_freq != args.profiling.profile_steps
|
||||||
|
), "Don't profile during probe step"
|
||||||
|
if args.logging.wandb is not None:
|
||||||
|
args.logging.wandb.name = args.name
|
||||||
|
|
||||||
|
if args.probe_freq is not None:
|
||||||
|
assert (
|
||||||
|
args.distributed.tp_size == 1
|
||||||
|
), "Probing not supported with tensor parallelism"
|
||||||
|
assert (
|
||||||
|
args.distributed.selective_activation_checkpointing is False
|
||||||
|
), "Probing not supported with selective activation checkpointing"
|
||||||
|
|
||||||
|
|
||||||
|
preemption_flag = dict(flag=False)
|
||||||
|
|
||||||
|
|
||||||
|
def set_preemption_flag(signum, frame):
|
||||||
|
logger.warning("Signal handler called with signal " + str(signum))
|
||||||
|
logger.warning("Preemption ! checkpointing asap and exiting.")
|
||||||
|
preemption_flag["flag"] = True
|
||||||
|
|
||||||
|
|
||||||
|
def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
|
||||||
|
test = train_state.step % freq == 0
|
||||||
|
if acc_step is not None:
|
||||||
|
test = test and (train_state.acc_step == acc_step)
|
||||||
|
elif acc_freq is not None:
|
||||||
|
test = test and ((train_state.acc_step % acc_freq) == 0)
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def train(args: TrainArgs):
|
||||||
|
with ExitStack() as context_stack:
|
||||||
|
tokenizer_args = TokenizerArgs(
|
||||||
|
name=args.data.name,
|
||||||
|
init_kwargs=args.data.tokenizer.init_kwargs,
|
||||||
|
)
|
||||||
|
tokenizer = tokenizer_args.build()
|
||||||
|
validate_train_args(
|
||||||
|
args,
|
||||||
|
tokenizer.n_words,
|
||||||
|
)
|
||||||
|
if get_is_master():
|
||||||
|
os.makedirs(args.dump_dir, exist_ok=True)
|
||||||
|
dump_config(args, Path(args.dump_dir) / "config.yaml")
|
||||||
|
init_logger(Path(args.dump_dir) / "train.log")
|
||||||
|
init_signal_handler(set_preemption_flag) # For handling preemption signals.
|
||||||
|
setup_env(args.env)
|
||||||
|
setup_torch_distributed(args.distributed)
|
||||||
|
world_mesh = get_device_mesh(args.distributed)
|
||||||
|
logger.info(f"Starting job: {args.name}")
|
||||||
|
|
||||||
|
# build dataloader
|
||||||
|
# need dp world size and rank
|
||||||
|
dp_mesh = world_mesh["dp_replicate"]
|
||||||
|
dp_degree = dp_mesh.size()
|
||||||
|
dp_rank = dp_mesh.get_local_rank()
|
||||||
|
if args.distributed.dp_shard > 1:
|
||||||
|
dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
|
||||||
|
dp_degree *= world_mesh["dp_shard"].size()
|
||||||
|
|
||||||
|
logger.info(f"Running on dp rank : {dp_rank}")
|
||||||
|
logger.info(f"Running on dp size : {dp_degree}")
|
||||||
|
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
logger.info("Building model")
|
||||||
|
|
||||||
|
# Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
|
||||||
|
with torch.device("meta"):
|
||||||
|
model = LMTransformer(args.model)
|
||||||
|
logger.info("Model is built !")
|
||||||
|
|
||||||
|
model_param_count = get_num_params(model)
|
||||||
|
|
||||||
|
model = parallelize_model(
|
||||||
|
model,
|
||||||
|
world_mesh,
|
||||||
|
args.model,
|
||||||
|
args.distributed,
|
||||||
|
fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
|
||||||
|
tp_parallelize=tp_parallelize,
|
||||||
|
no_recompute_ops=get_no_recompute_ops(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Once we shard the model on different gpus we can actually initialize the model
|
||||||
|
# First we create empty tensors of the correct shapes
|
||||||
|
model = model.to_empty(device="cuda")
|
||||||
|
# Then we init the model. Please make sure this function initializes *ALL* parameters
|
||||||
|
# and buffers, otherwise you will have random values in the unitialized tensors
|
||||||
|
# which will silently fail (give nan gradients for example)
|
||||||
|
|
||||||
|
if args.checkpoint.init_ckpt_path:
|
||||||
|
logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
|
||||||
|
load_from_checkpoint(
|
||||||
|
args.checkpoint.init_ckpt_path, model, model_key="model"
|
||||||
|
) # Put model_key="" if its directly the model checkpoint
|
||||||
|
model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded
|
||||||
|
else:
|
||||||
|
with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
|
||||||
|
torch.manual_seed(args.model.seed)
|
||||||
|
model.init_weights()
|
||||||
|
check_model_value_range(model, range=10.0, std=1.0)
|
||||||
|
|
||||||
|
# log model size
|
||||||
|
|
||||||
|
logger.info(f"Model size: {model_param_count:,} total parameters")
|
||||||
|
|
||||||
|
gpu_memory_monitor = GPUMemoryMonitor("cuda")
|
||||||
|
logger.info(
|
||||||
|
f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
|
||||||
|
f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
|
||||||
|
)
|
||||||
|
logger.info(f"GPU memory usage: {gpu_memory_monitor}")
|
||||||
|
|
||||||
|
# build optimizer after apply parallelisms to the model
|
||||||
|
optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
|
||||||
|
data_loader_state = init_dataloader_state_from_args(
|
||||||
|
args.data, dp_rank, dp_degree
|
||||||
|
)
|
||||||
|
|
||||||
|
train_state = TrainState(
|
||||||
|
step=0,
|
||||||
|
acc_step=0,
|
||||||
|
data_loader_state=data_loader_state,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
|
||||||
|
checkpoint.load(model, optimizer, train_state, world_mesh)
|
||||||
|
# Either load from latest checkpoint or start from scratch
|
||||||
|
if args.probe_freq is not None:
|
||||||
|
if get_is_master():
|
||||||
|
os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
probe = AutoProbeD(
|
||||||
|
model,
|
||||||
|
(
|
||||||
|
Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
|
||||||
|
if (dp_rank % 128 == 0)
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
probe_mod = model._orig_mod if args.distributed.compile else model
|
||||||
|
|
||||||
|
gc.disable()
|
||||||
|
|
||||||
|
# train loop
|
||||||
|
model.train()
|
||||||
|
metric_logger = context_stack.enter_context(
|
||||||
|
MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
|
||||||
|
)
|
||||||
|
data_loader = context_stack.enter_context(
|
||||||
|
build_dataloader_from_args(
|
||||||
|
args.data,
|
||||||
|
state=train_state.data_loader_state,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
torch_profiler = context_stack.enter_context(
|
||||||
|
maybe_run_profiler(args.dump_dir, model, args.profiling)
|
||||||
|
)
|
||||||
|
|
||||||
|
nwords_since_last_log = 0
|
||||||
|
time_last_log = timer()
|
||||||
|
gc.collect()
|
||||||
|
while train_state.step < args.steps:
|
||||||
|
# We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
|
||||||
|
train_state.acc_step += 1
|
||||||
|
train_state.acc_step = train_state.acc_step % args.grad_acc_steps
|
||||||
|
|
||||||
|
# get batch
|
||||||
|
curr_lr = float(optimizer.param_groups[0]["lr"])
|
||||||
|
data_load_start = timer()
|
||||||
|
batch, train_state.data_loader_state = next(data_loader)
|
||||||
|
batch = torch.tensor(
|
||||||
|
batch,
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
|
||||||
|
if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
|
||||||
|
logger.info("garbage collection")
|
||||||
|
# we do garbage collection manually otherwise different processes
|
||||||
|
# run the GC at different times so they slow down the whole pipeline
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
input_ids = batch[:, :, 0].cuda()
|
||||||
|
labels = batch[:, :, 1].cuda()
|
||||||
|
data_load_time = round(timer() - data_load_start, 4)
|
||||||
|
nwords_since_last_log += input_ids.numel()
|
||||||
|
|
||||||
|
bsz, seqlen = labels.shape
|
||||||
|
|
||||||
|
# forward
|
||||||
|
start_timer = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_timer = torch.cuda.Event(enable_timing=True)
|
||||||
|
start_timer.record()
|
||||||
|
|
||||||
|
# This is an automatic probe that will compute statistics
|
||||||
|
# of all linears' inputs, weights and outputs
|
||||||
|
# along with attention logits and entropy
|
||||||
|
# both in forward and backward pass
|
||||||
|
if (args.probe_freq is not None) and every_n_steps(
|
||||||
|
train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
|
||||||
|
):
|
||||||
|
# Here we do a fake forward and backward pass on a smaller
|
||||||
|
# batch size to avoid OOM
|
||||||
|
# This assumes the model has no stateful layers (batch norm..)
|
||||||
|
assert (
|
||||||
|
next(probe_mod.parameters()).grad is None
|
||||||
|
), "Can't probe model if grads are not reset"
|
||||||
|
|
||||||
|
with probe:
|
||||||
|
probe.metadata = {
|
||||||
|
"it": train_state.step,
|
||||||
|
"global_step": train_state.step,
|
||||||
|
"loop": "lingua",
|
||||||
|
}
|
||||||
|
# Non compiled model uses roughly 2x memory in our exps
|
||||||
|
# So we divide bsz by 2 or seqlen by 2
|
||||||
|
probe_bsz = max(1, bsz // 2)
|
||||||
|
probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
|
||||||
|
probe_loss = probe_mod(
|
||||||
|
input_ids[:probe_bsz, :probe_seq],
|
||||||
|
labels[:probe_bsz, :probe_seq],
|
||||||
|
)
|
||||||
|
probe_loss.backward()
|
||||||
|
# We zero grads to cancel this fake step
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
next(probe_mod.parameters()).grad is None
|
||||||
|
), "Probe model shouldn't have grads at this point"
|
||||||
|
loss = model(input_ids, labels)
|
||||||
|
|
||||||
|
# We scale loss with grad_acc_steps so the gradient is the same
|
||||||
|
# regardless of grad_acc_steps
|
||||||
|
loss = loss / args.grad_acc_steps
|
||||||
|
# backward on scaled loss to create scaled gradients
|
||||||
|
loss.backward()
|
||||||
|
# For logging we undo that scaling
|
||||||
|
loss = loss.detach() * args.grad_acc_steps
|
||||||
|
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
model.parameters(), max_norm=args.optim.clip, foreach=True
|
||||||
|
)
|
||||||
|
|
||||||
|
grad_norm = (
|
||||||
|
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
|
||||||
|
).item()
|
||||||
|
|
||||||
|
# optimizer step
|
||||||
|
if train_state.acc_step == 0:
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
train_state.step += 1
|
||||||
|
|
||||||
|
# updates the scale for next iteration
|
||||||
|
# training iteration complete
|
||||||
|
end_timer.record()
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)
|
||||||
|
|
||||||
|
# if profiler is active
|
||||||
|
if torch_profiler:
|
||||||
|
xformers.profiler.step()
|
||||||
|
|
||||||
|
# log metrics
|
||||||
|
if every_n_steps(
|
||||||
|
train_state,
|
||||||
|
args.logging.freq,
|
||||||
|
acc_step=None if args.logging.acc_freq else 0,
|
||||||
|
acc_freq=args.logging.acc_freq,
|
||||||
|
):
|
||||||
|
time_delta = timer() - time_last_log
|
||||||
|
wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)
|
||||||
|
|
||||||
|
gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
|
||||||
|
|
||||||
|
total_acc_steps = (
|
||||||
|
args.grad_acc_steps * train_state.step + train_state.acc_step
|
||||||
|
)
|
||||||
|
tokens_per_gpu = (
|
||||||
|
total_acc_steps * args.data.batch_size * args.data.seq_len
|
||||||
|
)
|
||||||
|
total_tokens = dp_degree * tokens_per_gpu
|
||||||
|
# This is an estimate and the correct values may change
|
||||||
|
# if you change the architecture
|
||||||
|
# Use xformer's analyze profile trace to get actual measurement
|
||||||
|
FLOPS = (
|
||||||
|
get_num_flop_per_token(
|
||||||
|
model_param_count - args.model.vocab_size * args.model.dim,
|
||||||
|
args.model.n_layers,
|
||||||
|
args.model.dim,
|
||||||
|
args.data.seq_len,
|
||||||
|
)
|
||||||
|
* wps
|
||||||
|
)
|
||||||
|
metrics = flatten_dict(
|
||||||
|
{
|
||||||
|
"global_step": train_state.step,
|
||||||
|
"acc_step": train_state.acc_step,
|
||||||
|
"speed": {
|
||||||
|
"wps": wps,
|
||||||
|
"FLOPS": FLOPS,
|
||||||
|
"curr_iter_time": curr_iter_time,
|
||||||
|
"data_load_time": data_load_time,
|
||||||
|
},
|
||||||
|
"optim": {
|
||||||
|
"grad_norm": grad_norm,
|
||||||
|
"lr": curr_lr,
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
},
|
||||||
|
"memory": gpu_mem_stats._asdict(),
|
||||||
|
},
|
||||||
|
sep="/",
|
||||||
|
)
|
||||||
|
|
||||||
|
to_sync = {}
|
||||||
|
to_sync["loss/out"] = loss.item()
|
||||||
|
metrics.update(dist_mean_dict(to_sync))
|
||||||
|
|
||||||
|
if get_is_master():
|
||||||
|
metric_logger.log(metrics)
|
||||||
|
|
||||||
|
gpu_memory_monitor.reset_peak_stats()
|
||||||
|
nwords_since_last_log = 0
|
||||||
|
time_last_log = timer()
|
||||||
|
logger.info(
|
||||||
|
f"step: {train_state.step}"
|
||||||
|
f" acc: {train_state.acc_step}"
|
||||||
|
f" loss: {round(loss.item(),4):>7}"
|
||||||
|
f" grad: {grad_norm:.2e}"
|
||||||
|
f" flops: {FLOPS:.2e}"
|
||||||
|
f" wps: {wps:.2e}"
|
||||||
|
f" iter: {curr_iter_time:>7}"
|
||||||
|
f" data: {data_load_time:>5}"
|
||||||
|
f" lr: {curr_lr:.2e}"
|
||||||
|
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
|
||||||
|
f" pow: {gpu_mem_stats.power_draw/1000} W"
|
||||||
|
)
|
||||||
|
|
||||||
|
saved = False
|
||||||
|
if every_n_steps(
|
||||||
|
train_state, args.checkpoint.dump.every, acc_step=0
|
||||||
|
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
|
||||||
|
saved = checkpoint.save(
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
train_state,
|
||||||
|
args,
|
||||||
|
device_mesh=world_mesh,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.eval is not None and every_n_steps(
|
||||||
|
train_state, args.checkpoint.eval.every, acc_step=0
|
||||||
|
):
|
||||||
|
from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
|
||||||
|
|
||||||
|
eval_args = dataclass_from_dict(EvalArgs, args.eval)
|
||||||
|
|
||||||
|
eval_args.global_step = train_state.step
|
||||||
|
eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
|
||||||
|
eval_args.dump_dir = str(
|
||||||
|
os.path.join(
|
||||||
|
args.dump_dir,
|
||||||
|
"evals",
|
||||||
|
EVAL_FOLDER_NAME.format(train_state.step),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
eval_args.metric_log_dir = args.dump_dir
|
||||||
|
if args.async_eval_gpus is None:
|
||||||
|
launch_eval(eval_args)
|
||||||
|
elif get_is_master():
|
||||||
|
if wandb.run is not None and args.logging.wandb is not None:
|
||||||
|
eval_args.wandb = deepcopy(args.logging.wandb)
|
||||||
|
assert args.async_eval_gpus > 0
|
||||||
|
logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
|
||||||
|
with clean_env():
|
||||||
|
launch_job(
|
||||||
|
StoolArgs(
|
||||||
|
asdict(eval_args),
|
||||||
|
script="apps.main.eval",
|
||||||
|
copy_code=False,
|
||||||
|
nodes=args.async_eval_gpus // 8,
|
||||||
|
qos="lowest",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if preemption_flag["flag"]:
|
||||||
|
if not saved:
|
||||||
|
checkpoint.save(
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
train_state,
|
||||||
|
args,
|
||||||
|
device_mesh=world_mesh,
|
||||||
|
)
|
||||||
|
requeue_slurm_job()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
if not saved:
|
||||||
|
checkpoint.save(
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
train_state,
|
||||||
|
args,
|
||||||
|
device_mesh=world_mesh,
|
||||||
|
)
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
|
||||||
|
This accepts arguments as a dot list
|
||||||
|
So if the dataclass looks like
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DummyArgs:
|
||||||
|
name: str
|
||||||
|
model: LMTransformerArgsgs
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LMTransformerArgsgs:
|
||||||
|
dim: int
|
||||||
|
|
||||||
|
Then you can pass model.dim=32 to change values in LMTransformerArgsgs
|
||||||
|
or just name=tictac for top level attributes.
|
||||||
|
|
||||||
|
The behavior here is as follows:
|
||||||
|
1. We instantiate TrainArgs with its default values
|
||||||
|
2. We override those default values with the ones in the provided config file
|
||||||
|
3. We override the result with the additional arguments provided through command line
|
||||||
|
|
||||||
|
For example, if the config is the following
|
||||||
|
|
||||||
|
model:
|
||||||
|
dim: 128
|
||||||
|
n_layers: 4
|
||||||
|
|
||||||
|
and you call train.py with train.py model.dim=64
|
||||||
|
|
||||||
|
Then the final TrainArgs will have
|
||||||
|
|
||||||
|
model:
|
||||||
|
dim: 64
|
||||||
|
n_layers: 4
|
||||||
|
|
||||||
|
Plus all the default values in TrainArgs dataclass.
|
||||||
|
"""
|
||||||
|
cli_args = OmegaConf.from_cli()
|
||||||
|
file_cfg = OmegaConf.load(cli_args.config)
|
||||||
|
# We remove 'config' attribute from config as the underlying DataClass does not have it
|
||||||
|
del cli_args.config
|
||||||
|
|
||||||
|
default_cfg = OmegaConf.structured(TrainArgs())
|
||||||
|
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
|
||||||
|
cfg = OmegaConf.to_object(cfg)
|
||||||
|
|
||||||
|
train(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
BIN
blt-figure.jpg
Normal file
BIN
blt-figure.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 54 KiB |
BIN
blt-figure.pdf
Normal file
BIN
blt-figure.pdf
Normal file
Binary file not shown.
BIN
bytelatent/.DS_Store
vendored
Normal file
BIN
bytelatent/.DS_Store
vendored
Normal file
Binary file not shown.
3
bytelatent/__init__.py
Normal file
3
bytelatent/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
class ByteLatentError(Exception):
|
||||||
|
pass
|
199
bytelatent/args.py
Normal file
199
bytelatent/args.py
Normal file
|
@ -0,0 +1,199 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from bytelatent.checkpoint import CheckpointArgs
|
||||||
|
from bytelatent.data.data_types import Batch
|
||||||
|
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
|
||||||
|
from bytelatent.data.iterators.arrow_iterator import (
|
||||||
|
ArrowFileIterator,
|
||||||
|
find_and_sanitize_chunks,
|
||||||
|
)
|
||||||
|
from bytelatent.data.iterators.looping_iterator import LoopingIterator
|
||||||
|
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
|
||||||
|
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
|
||||||
|
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
|
||||||
|
from bytelatent.data.iterators.sampling_iterator import SamplingIterator
|
||||||
|
from bytelatent.data.iterators.sequence_iterator import (
|
||||||
|
SequenceIterator,
|
||||||
|
SequencePackingArgs,
|
||||||
|
)
|
||||||
|
from bytelatent.data.patcher import PatcherArgs
|
||||||
|
from bytelatent.distributed import DistributedArgs, EnvironmentArgs
|
||||||
|
from bytelatent.metrics import LoggingArgs
|
||||||
|
from bytelatent.model.blt import ByteLatentTransformerArgs
|
||||||
|
from bytelatent.optim import OptimArgs
|
||||||
|
from bytelatent.profiling import ProfilerArgs
|
||||||
|
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
|
||||||
|
return np.random.default_rng((seed, rank, world_size)).bit_generator.state
|
||||||
|
|
||||||
|
|
||||||
|
def distribute_data_to_rank(
|
||||||
|
*,
|
||||||
|
dataset_path: str,
|
||||||
|
preprocess_dir: str,
|
||||||
|
entropy_model_name: str | None,
|
||||||
|
arrow_batch_size: int,
|
||||||
|
rank: int,
|
||||||
|
world_size: int,
|
||||||
|
) -> ArrowFileIterator:
|
||||||
|
dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size)
|
||||||
|
n_workers_per_chunk = world_size // len(dataset_chunks)
|
||||||
|
rank_to_arrow_iterator_params = []
|
||||||
|
for chunk_path in dataset_chunks:
|
||||||
|
for worker_id in range(n_workers_per_chunk):
|
||||||
|
rank_to_arrow_iterator_params.append(
|
||||||
|
ArrowFileIterator(
|
||||||
|
file_path=chunk_path,
|
||||||
|
worker_id=worker_id,
|
||||||
|
num_workers=n_workers_per_chunk,
|
||||||
|
preprocess_dir=preprocess_dir,
|
||||||
|
dataset_files=None,
|
||||||
|
entropy_model_name=entropy_model_name,
|
||||||
|
arrow_batch_size=arrow_batch_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return rank_to_arrow_iterator_params[rank]
|
||||||
|
|
||||||
|
|
||||||
|
class DataloaderArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
root_dir: str | None = None
|
||||||
|
sources: dict[str, float] = {}
|
||||||
|
batch_size: int = 2
|
||||||
|
seq_len: int = 2048
|
||||||
|
seed: int = 42
|
||||||
|
add_bos: bool = True
|
||||||
|
add_eos: bool = True
|
||||||
|
load_async: bool = True
|
||||||
|
prefetch_size: int = 64
|
||||||
|
preprocess_dir: str | None = None
|
||||||
|
dataset_files: list[str] | None = None
|
||||||
|
entropy_model_name: str | None = "transformer_100m"
|
||||||
|
arrow_batch_size: int = 100
|
||||||
|
buffer_size: int = 64
|
||||||
|
|
||||||
|
pad_to_max_length: bool = True
|
||||||
|
max_encoder_seq_length: int = 12288
|
||||||
|
enable_byte_ngrams: bool = False
|
||||||
|
|
||||||
|
tokenizer_args: TokenizerArgs = TokenizerArgs()
|
||||||
|
patcher_args: PatcherArgs = PatcherArgs()
|
||||||
|
|
||||||
|
def _create_sequence_iterators(
|
||||||
|
self, rank: int, world_size: int
|
||||||
|
) -> dict[str, SequenceIterator]:
|
||||||
|
sequence_packing_args = SequencePackingArgs(
|
||||||
|
output_seq_len=self.seq_len,
|
||||||
|
buffer_size=self.buffer_size,
|
||||||
|
)
|
||||||
|
source_to_sequence_iterator: dict[str, SequenceIterator] = {}
|
||||||
|
for dataset_path in self.sources:
|
||||||
|
shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
|
||||||
|
arrow_iterator = distribute_data_to_rank(
|
||||||
|
dataset_path=os.path.join(self.root_dir, dataset_path),
|
||||||
|
preprocess_dir=self.preprocess_dir,
|
||||||
|
entropy_model_name=self.entropy_model_name,
|
||||||
|
arrow_batch_size=self.arrow_batch_size,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
looping_iterator = LoopingIterator(arrow_iterator)
|
||||||
|
preprocess_iterator = PreprocessIterator(
|
||||||
|
looping_iterator,
|
||||||
|
patcher_args=self.patcher_args,
|
||||||
|
tokenizer_args=self.tokenizer_args,
|
||||||
|
)
|
||||||
|
sequence_iterator = SequenceIterator(
|
||||||
|
preprocess_iterator,
|
||||||
|
sequence_packing_args=sequence_packing_args,
|
||||||
|
rng_state=shuffle_rng_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
source_to_sequence_iterator[dataset_path] = sequence_iterator
|
||||||
|
return source_to_sequence_iterator
|
||||||
|
|
||||||
|
def build_from_rank(
|
||||||
|
self, rank: int, world_size: int
|
||||||
|
) -> StatefulIterator[Batch, Any]:
|
||||||
|
source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size)
|
||||||
|
weight_rng_state = get_rng_state(self.seed + 1, rank, world_size)
|
||||||
|
sampling_iterator = SamplingIterator(
|
||||||
|
rng_state=weight_rng_state,
|
||||||
|
source_to_weight=self.sources,
|
||||||
|
source_to_iterator=source_to_sequence_iterators,
|
||||||
|
)
|
||||||
|
tokenizer = self.tokenizer_args.build()
|
||||||
|
packing_args = PackingArgs(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
seq_len=self.seq_len,
|
||||||
|
pad_id=tokenizer.boe_id,
|
||||||
|
max_length=self.max_encoder_seq_length,
|
||||||
|
pad_to_max_length=self.pad_to_max_length,
|
||||||
|
enable_byte_ngrams=self.enable_byte_ngrams,
|
||||||
|
)
|
||||||
|
packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
|
||||||
|
mp_iterator = MultiprocessIterator(
|
||||||
|
packing_iterator, n_batches_to_prefetch=self.prefetch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
return mp_iterator
|
||||||
|
|
||||||
|
|
||||||
|
class TrainArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
name: str = "lingua"
|
||||||
|
dump_dir: str = ""
|
||||||
|
|
||||||
|
seed: int = 42
|
||||||
|
|
||||||
|
# Number of gradient accumulation steps
|
||||||
|
# Total batch size is batch_size*grad_acc_steps
|
||||||
|
grad_acc_steps: int = 1
|
||||||
|
|
||||||
|
gc_collect_freq: int = 1000
|
||||||
|
probe_freq: int | None = None
|
||||||
|
|
||||||
|
# Nb optimizer steps to take
|
||||||
|
steps: int = 1000
|
||||||
|
|
||||||
|
data: DataloaderArgs = DataloaderArgs()
|
||||||
|
optim: OptimArgs = OptimArgs()
|
||||||
|
model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
|
||||||
|
distributed: DistributedArgs = DistributedArgs()
|
||||||
|
env: EnvironmentArgs = EnvironmentArgs()
|
||||||
|
|
||||||
|
checkpoint: CheckpointArgs = CheckpointArgs()
|
||||||
|
profiling: ProfilerArgs = ProfilerArgs()
|
||||||
|
logging: LoggingArgs = LoggingArgs()
|
||||||
|
|
||||||
|
# If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
|
||||||
|
async_eval_gpus: int | None = None
|
||||||
|
eval: Any | None = None
|
||||||
|
eval_on_gpus: int | None = None
|
||||||
|
|
||||||
|
def dump_to_yaml_file(
|
||||||
|
self, path: str, log_config: bool = True, sort_keys: bool = True
|
||||||
|
):
|
||||||
|
model_dict = self.model_dump(mode="json")
|
||||||
|
yaml_str = yaml.dump(
|
||||||
|
model_dict,
|
||||||
|
allow_unicode=True,
|
||||||
|
sort_keys=sort_keys,
|
||||||
|
default_flow_style=False,
|
||||||
|
)
|
||||||
|
with open(path, "w") as f:
|
||||||
|
if log_config:
|
||||||
|
logger.info("Using the following config for this run:")
|
||||||
|
logger.info(yaml_str)
|
||||||
|
f.write(yaml_str)
|
585
bytelatent/base_transformer.py
Normal file
585
bytelatent/base_transformer.py
Normal file
|
@ -0,0 +1,585 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.nn.attention.flex_attention import (
|
||||||
|
BlockMask,
|
||||||
|
_mask_mod_signature,
|
||||||
|
flex_attention,
|
||||||
|
)
|
||||||
|
from xformers.ops import AttentionBias, fmha
|
||||||
|
|
||||||
|
from bytelatent import probe
|
||||||
|
|
||||||
|
flex_attention_comp = torch.compile(flex_attention)
|
||||||
|
|
||||||
|
|
||||||
|
class InitStdFactor(Enum):
|
||||||
|
DISABLED = "disabled" # Init std is divided by 1.0
|
||||||
|
GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers)
|
||||||
|
CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
|
||||||
|
DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTransformerArgs(BaseModel):
|
||||||
|
dim: int = 512
|
||||||
|
n_layers: int = 8
|
||||||
|
head_dim: Optional[int] = None
|
||||||
|
n_heads: Optional[int] = None
|
||||||
|
n_kv_heads: Optional[int] = None
|
||||||
|
|
||||||
|
ffn_dim_multiplier: Optional[float] = None
|
||||||
|
|
||||||
|
multiple_of: int = 256
|
||||||
|
|
||||||
|
norm_eps: float = 1e-5
|
||||||
|
|
||||||
|
rope_theta: float = 10000.0
|
||||||
|
|
||||||
|
init_base_std: Optional[float] = None
|
||||||
|
init_std_factor: InitStdFactor = InitStdFactor.DISABLED
|
||||||
|
|
||||||
|
max_seqlen: int = 1024
|
||||||
|
|
||||||
|
|
||||||
|
def cross_entropy(pred, target, **kwargs):
|
||||||
|
return F.nll_loss(
|
||||||
|
F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
|
||||||
|
target.flatten(end_dim=-1),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
|
||||||
|
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
||||||
|
assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
|
||||||
|
bs, slen, n_kv_heads, head_dim = x.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return x
|
||||||
|
return (
|
||||||
|
x[:, :, :, None, :]
|
||||||
|
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
||||||
|
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
||||||
|
"""
|
||||||
|
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
||||||
|
|
||||||
|
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
|
||||||
|
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
||||||
|
The returned tensor contains complex values in complex64 data type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): Dimension of the frequency tensor.
|
||||||
|
end (int): End index for precomputing frequencies.
|
||||||
|
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Precomputed frequency tensor with complex exponentials.
|
||||||
|
"""
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||||
|
t = torch.arange(end, device=freqs.device)
|
||||||
|
freqs = torch.outer(t, freqs).float()
|
||||||
|
|
||||||
|
cos, sin = freqs.cos(), freqs.sin()
|
||||||
|
|
||||||
|
return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
|
||||||
|
"""
|
||||||
|
Reshape frequency tensor for broadcasting it with another tensor.
|
||||||
|
|
||||||
|
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
||||||
|
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
|
||||||
|
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
||||||
|
seq_dim (int): Sequence dimension index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Reshaped frequency tensor.
|
||||||
|
"""
|
||||||
|
ndim = x.ndim
|
||||||
|
assert 0 <= seq_dim < ndim
|
||||||
|
assert freqs_cis.shape == (
|
||||||
|
x.shape[seq_dim],
|
||||||
|
x.shape[-3],
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
|
||||||
|
shape = [
|
||||||
|
d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
|
||||||
|
] + [2, 2]
|
||||||
|
return freqs_cis.view(*shape)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(
|
||||||
|
xq: torch.Tensor,
|
||||||
|
xk: torch.Tensor,
|
||||||
|
seq_dim: int,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
|
||||||
|
xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
|
||||||
|
freqs_cis = reshape_for_broadcast(
|
||||||
|
freqs_cis, xq_, seq_dim
|
||||||
|
).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
|
||||||
|
xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
|
||||||
|
xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
|
||||||
|
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||||
|
|
||||||
|
|
||||||
|
def causal_mask(b, h, q_idx, kv_idx):
|
||||||
|
return q_idx >= kv_idx
|
||||||
|
|
||||||
|
|
||||||
|
def lengths_to_start_ids(lengths):
|
||||||
|
doc_start = lengths.cumsum(0)
|
||||||
|
doc_start = doc_start.roll(1)
|
||||||
|
doc_start[0] = 0
|
||||||
|
return doc_start
|
||||||
|
|
||||||
|
|
||||||
|
def lengths_to_local_ids(lengths):
|
||||||
|
assert lengths.ndim == 1
|
||||||
|
nb_seqs = lengths.size(0)
|
||||||
|
total_seqlen = lengths.sum()
|
||||||
|
# This gives the document id of each token
|
||||||
|
doc_id = torch.repeat_interleave(lengths)
|
||||||
|
# Compute document start for each document
|
||||||
|
doc_start = lengths_to_start_ids(lengths)
|
||||||
|
# Compute document start for each token
|
||||||
|
doc_start = doc_start[doc_id]
|
||||||
|
# Compute the position of each token within each document
|
||||||
|
tok_id = torch.arange(total_seqlen, device=lengths.device) - doc_start
|
||||||
|
|
||||||
|
return doc_id, tok_id
|
||||||
|
|
||||||
|
|
||||||
|
def generate_doc_mask_mod(
|
||||||
|
mask_mod: _mask_mod_signature,
|
||||||
|
lengths: torch.Tensor,
|
||||||
|
kv_lengths: Optional[torch.Tensor] = None,
|
||||||
|
) -> _mask_mod_signature:
|
||||||
|
"""Generates mask mods that apply to inputs to flex attention in the sequence stacked
|
||||||
|
format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask_mod: The mask mod to apply to the documents
|
||||||
|
lengths: Lengths of each document
|
||||||
|
|
||||||
|
Note:
|
||||||
|
What is the sequence stacked format? When assembling batches of inputs, we
|
||||||
|
take multiple sequences and stack them together to form 1 large sequence. We then
|
||||||
|
use masking to ensure that the attention scores are only applied to tokens within
|
||||||
|
the same document.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
- Square mask
|
||||||
|
doc_mask lengths
|
||||||
|
a a b b b c c 2 3 2
|
||||||
|
a 1 0 0 0 0 0 0
|
||||||
|
a 1 1 0 0 0 0 0
|
||||||
|
b 0 0 1 0 0 0 0
|
||||||
|
b 0 0 1 1 0 0 0
|
||||||
|
b 0 0 1 1 1 0 0
|
||||||
|
c 0 0 0 0 0 1 0
|
||||||
|
c 0 0 0 0 0 1 1
|
||||||
|
|
||||||
|
"""
|
||||||
|
kv_lengths = kv_lengths if kv_lengths is not None else lengths
|
||||||
|
q_document_id, q_token_id = lengths_to_local_ids(lengths)
|
||||||
|
kv_document_id, kv_token_id = lengths_to_local_ids(kv_lengths)
|
||||||
|
q_max_idx = lengths.sum() - 1
|
||||||
|
kv_max_idx = kv_lengths.sum() - 1
|
||||||
|
|
||||||
|
def doc_mask_mod(b, h, q_idx, kv_idx):
|
||||||
|
q_idx_cap = torch.minimum(q_max_idx, q_idx)
|
||||||
|
kv_idx_cap = torch.minimum(kv_max_idx, kv_idx)
|
||||||
|
valid_idx = (q_idx <= q_max_idx) & (kv_idx <= kv_max_idx)
|
||||||
|
same_doc = q_document_id[q_idx_cap] == kv_document_id[kv_idx_cap]
|
||||||
|
q_logical = q_token_id[q_idx_cap]
|
||||||
|
kv_logical = kv_token_id[kv_idx_cap]
|
||||||
|
inner_mask = mask_mod(b, h, q_logical, kv_logical)
|
||||||
|
return same_doc & inner_mask & valid_idx
|
||||||
|
|
||||||
|
return doc_mask_mod
|
||||||
|
|
||||||
|
|
||||||
|
# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed.
|
||||||
|
class RotaryEmbedding(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
RotaryEmbedding Module
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.theta = theta
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.max_seqlen = max_seqlen
|
||||||
|
|
||||||
|
self.register_buffer(
|
||||||
|
"freqs_cis",
|
||||||
|
precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
self.freqs_cis[...] = precompute_freqs_cis(
|
||||||
|
dim=self.head_dim, end=self.max_seqlen, theta=self.theta
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions
|
||||||
|
Args:
|
||||||
|
seqlen (int): Contiguous sequence length
|
||||||
|
tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis
|
||||||
|
"""
|
||||||
|
test = (seqlen is not None) or (tok_idx is not None)
|
||||||
|
assert test, "Should provide atleast seqlen or tok_idx"
|
||||||
|
if tok_idx is not None:
|
||||||
|
return self.freqs_cis[tok_idx]
|
||||||
|
elif seqlen is not None:
|
||||||
|
return self.freqs_cis[0:seqlen]
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
"""
|
||||||
|
Initialize the RMSNorm normalization layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): The dimension of the input tensor.
|
||||||
|
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
eps (float): A small value added to the denominator for numerical stability.
|
||||||
|
weight (nn.Parameter): Learnable scaling parameter.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def _norm(self, x: torch.Tensor):
|
||||||
|
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = probe.log_stats(x, "resid")
|
||||||
|
output = self._norm(x.float())
|
||||||
|
return (output * self.weight.float()).type_as(x)
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
torch.nn.init.ones_(self.weight) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
head_dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
rope_theta: float,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_kv_heads = n_kv_heads
|
||||||
|
self.heads_per_group = self.n_heads // self.n_kv_heads
|
||||||
|
|
||||||
|
self.wq = nn.Linear(
|
||||||
|
dim,
|
||||||
|
n_heads * head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.wk = nn.Linear(
|
||||||
|
dim,
|
||||||
|
n_kv_heads * head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.wv = nn.Linear(
|
||||||
|
dim,
|
||||||
|
n_kv_heads * head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.wo = nn.Linear(
|
||||||
|
n_heads * head_dim,
|
||||||
|
dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
freq_cis: torch.Tensor,
|
||||||
|
tok_idx: Optional[torch.Tensor] = None,
|
||||||
|
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
|
||||||
|
attn_impl: str = "sdpa",
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# B S D
|
||||||
|
bsz, seq_len, dim = x.shape
|
||||||
|
xq = self.wq(x.view_as(x))
|
||||||
|
xk = self.wk(x.view_as(x))
|
||||||
|
xv = self.wv(x.view_as(x))
|
||||||
|
|
||||||
|
output_shape = xq.shape
|
||||||
|
# B S D -> B S H D
|
||||||
|
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
|
||||||
|
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
|
||||||
|
xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
|
||||||
|
|
||||||
|
xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
|
||||||
|
|
||||||
|
# This condition helps us be easily compatible
|
||||||
|
# with inference by adding a pluggable KVCache
|
||||||
|
if hasattr(self, "kv_cache"):
|
||||||
|
xk, xv = self.kv_cache.update(xk, xv, tok_idx)
|
||||||
|
|
||||||
|
xk = repeat_kv(xk, self.heads_per_group, dim=2)
|
||||||
|
xv = repeat_kv(xv, self.heads_per_group, dim=2)
|
||||||
|
|
||||||
|
if attn_impl == "flex_attention":
|
||||||
|
assert mask is None or isinstance(mask, BlockMask)
|
||||||
|
xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
|
||||||
|
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
|
||||||
|
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
|
||||||
|
|
||||||
|
elif attn_impl == "fmha":
|
||||||
|
assert mask is None or isinstance(mask, AttentionBias)
|
||||||
|
output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
|
||||||
|
# This uses B S H D instead of B H S D of pytorch
|
||||||
|
|
||||||
|
elif attn_impl == "sdpa":
|
||||||
|
xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
|
||||||
|
assert mask is None or isinstance(mask, (str, torch.Tensor))
|
||||||
|
is_causal = (mask == "causal") if isinstance(mask, str) else False
|
||||||
|
mask = mask if isinstance(mask, torch.Tensor) else None
|
||||||
|
output = F.scaled_dot_product_attention(
|
||||||
|
xq,
|
||||||
|
xk,
|
||||||
|
xv,
|
||||||
|
is_causal=is_causal,
|
||||||
|
attn_mask=mask,
|
||||||
|
)
|
||||||
|
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Attention implementation {attn_impl} not supported"
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.wo(output.reshape(output_shape))
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def reset_parameters(self, init_std=None, factor=1.0):
|
||||||
|
init_std = init_std or (self.dim ** (-0.5))
|
||||||
|
|
||||||
|
for w in [self.wq, self.wk, self.wv]:
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
w.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=init_std,
|
||||||
|
a=-3 * init_std,
|
||||||
|
b=3 * init_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.wo.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=init_std / factor,
|
||||||
|
a=-3 * init_std,
|
||||||
|
b=3 * init_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
multiple_of: int,
|
||||||
|
ffn_dim_multiplier: Optional[float],
|
||||||
|
mp_size: int = 1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
hidden_dim = int(2 * hidden_dim / 3)
|
||||||
|
if ffn_dim_multiplier is not None:
|
||||||
|
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||||
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
|
assert hidden_dim % mp_size == 0
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
|
||||||
|
self.w1 = nn.Linear(
|
||||||
|
dim,
|
||||||
|
hidden_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.w3 = nn.Linear(
|
||||||
|
dim,
|
||||||
|
hidden_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.w2 = nn.Linear(
|
||||||
|
hidden_dim,
|
||||||
|
dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# B S D
|
||||||
|
x1 = self.w1(x.view_as(x))
|
||||||
|
x3 = self.w3(x.view_as(x))
|
||||||
|
output = self.w2(F.silu(x1) * x3)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def reset_parameters(self, init_std=None, factor=1.0):
|
||||||
|
in_init_std = init_std or (self.dim ** (-0.5))
|
||||||
|
out_init_std = init_std or (self.hidden_dim ** (-0.5))
|
||||||
|
in_init_std = in_init_std
|
||||||
|
out_init_std = out_init_std / factor
|
||||||
|
for w in [self.w1, self.w3]:
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
w.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=in_init_std,
|
||||||
|
a=-3 * in_init_std,
|
||||||
|
b=3 * in_init_std,
|
||||||
|
)
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.w2.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=out_init_std,
|
||||||
|
a=-3 * out_init_std,
|
||||||
|
b=3 * out_init_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, args: BaseTransformerArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert (args.head_dim is not None) or (
|
||||||
|
args.n_heads is not None
|
||||||
|
), "Should specify at least head_dim or n_heads"
|
||||||
|
self.head_dim = args.head_dim or args.dim // args.n_heads
|
||||||
|
self.n_heads = args.n_heads or args.dim // args.head_dim
|
||||||
|
self.n_kv_heads = args.n_kv_heads or self.n_heads
|
||||||
|
|
||||||
|
assert args.n_heads % self.n_kv_heads == 0
|
||||||
|
assert args.dim % args.n_heads == 0
|
||||||
|
|
||||||
|
self.attention = Attention(
|
||||||
|
dim=args.dim,
|
||||||
|
head_dim=self.head_dim,
|
||||||
|
n_heads=self.n_heads,
|
||||||
|
n_kv_heads=self.n_kv_heads,
|
||||||
|
rope_theta=args.rope_theta,
|
||||||
|
)
|
||||||
|
self.feed_forward = FeedForward(
|
||||||
|
dim=args.dim,
|
||||||
|
hidden_dim=4 * args.dim,
|
||||||
|
multiple_of=args.multiple_of,
|
||||||
|
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||||
|
)
|
||||||
|
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
freq_cis: torch.Tensor,
|
||||||
|
tok_idx: Optional[torch.Tensor] = None,
|
||||||
|
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
|
||||||
|
attn_impl: str = "sdpa",
|
||||||
|
) -> torch.Tensor:
|
||||||
|
h = x + self.attention(
|
||||||
|
self.attention_norm(x),
|
||||||
|
freq_cis,
|
||||||
|
tok_idx=tok_idx,
|
||||||
|
mask=mask,
|
||||||
|
attn_impl=attn_impl,
|
||||||
|
)
|
||||||
|
out = h + self.feed_forward(self.ffn_norm(h))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def init_weights(self, init_std=None, factor=1.0):
|
||||||
|
self.attention.reset_parameters(init_std, factor)
|
||||||
|
self.attention_norm.reset_parameters()
|
||||||
|
|
||||||
|
self.feed_forward.reset_parameters(init_std, factor)
|
||||||
|
self.ffn_norm.reset_parameters()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTransformer(nn.Module):
|
||||||
|
def __init__(self, args: BaseTransformerArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = args.dim
|
||||||
|
self.init_base_std = args.init_base_std
|
||||||
|
self.init_std_factor = InitStdFactor(args.init_std_factor)
|
||||||
|
self.max_seqlen = args.max_seqlen
|
||||||
|
self.rope_embeddings = RotaryEmbedding(
|
||||||
|
theta=args.rope_theta,
|
||||||
|
head_dim=args.head_dim or args.dim // args.n_heads,
|
||||||
|
max_seqlen=args.max_seqlen,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
for _ in range(args.n_layers):
|
||||||
|
self.layers.append(TransformerBlock(args))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
h,
|
||||||
|
tok_idx: Optional[torch.Tensor] = None,
|
||||||
|
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
|
||||||
|
attn_impl: str = "sdpa",
|
||||||
|
):
|
||||||
|
|
||||||
|
freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
|
||||||
|
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
# Either use fixed base std or sqrt model dim
|
||||||
|
self.rope_embeddings.reset_parameters()
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
self.reset_parameters()
|
||||||
|
for depth, layer in enumerate(self.layers):
|
||||||
|
factor = {
|
||||||
|
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
|
||||||
|
InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
|
||||||
|
InitStdFactor.DIM_RATIO: self.dim / 4096,
|
||||||
|
InitStdFactor.DISABLED: 1.0,
|
||||||
|
}[self.init_std_factor]
|
||||||
|
|
||||||
|
layer.init_weights(self.init_base_std, factor)
|
311
bytelatent/checkpoint.py
Normal file
311
bytelatent/checkpoint.py
Normal file
|
@ -0,0 +1,311 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.distributed.checkpoint as dcp
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim.optimizer
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from torch.distributed._tensor import DeviceMesh
|
||||||
|
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
|
||||||
|
from torch.distributed.checkpoint.state_dict import (
|
||||||
|
get_model_state_dict,
|
||||||
|
get_state_dict,
|
||||||
|
set_state_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
from bytelatent.distributed import get_is_master
|
||||||
|
|
||||||
|
logger = logging.getLogger("CHECKPOINT")
|
||||||
|
|
||||||
|
FOLDER_NAME = "{:010d}"
|
||||||
|
RE_FOLDER = r"\d{10}"
|
||||||
|
|
||||||
|
RE_CKPT = r"__\d_\d\.distcp"
|
||||||
|
|
||||||
|
CONSOLIDATE_FOLDER = "consolidated"
|
||||||
|
CONSOLIDATE_NAME = "consolidated.pth"
|
||||||
|
|
||||||
|
CONFIG_NAME = "params.json"
|
||||||
|
TRAIN_STATE_NAME = "train_state_{:05d}.json"
|
||||||
|
RE_DIGITS = re.compile(r"\d+")
|
||||||
|
|
||||||
|
|
||||||
|
class SaveEvery(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
every: int = 1000
|
||||||
|
keep: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
dump: SaveEvery = SaveEvery()
|
||||||
|
eval: SaveEvery = SaveEvery()
|
||||||
|
path: str | None = None
|
||||||
|
init_ckpt_path: str | None = None
|
||||||
|
continue_training_from_init: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_key_step(name: str):
|
||||||
|
return int(re.findall(RE_DIGITS, name)[-1])
|
||||||
|
|
||||||
|
|
||||||
|
def consolidate_checkpoints(ckpt_dir: str):
|
||||||
|
"""
|
||||||
|
Consolidates all FSDP checkpoints in a directory to a single file
|
||||||
|
Consolidate checkpoint is saved in a subdirectory of ckpt_dir
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
ckpt_dir: str - path to the directory containing the checkpoints
|
||||||
|
|
||||||
|
Returns the path to the consolidated checkpoint
|
||||||
|
"""
|
||||||
|
consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
|
||||||
|
if not (consolidate_path / CONSOLIDATE_NAME).exists():
|
||||||
|
consolidate_path.mkdir(exist_ok=True)
|
||||||
|
logger.info(f"Consolidating to: {str(consolidate_path)}")
|
||||||
|
dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
|
||||||
|
(consolidate_path / CONFIG_NAME).write_text(
|
||||||
|
(Path(ckpt_dir) / CONFIG_NAME).read_text()
|
||||||
|
)
|
||||||
|
logger.info("Consolidated !")
|
||||||
|
return consolidate_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_checkpoint(
|
||||||
|
ckpt_dir: str,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
model_key: str = "model",
|
||||||
|
optim_key: str = "optim",
|
||||||
|
):
|
||||||
|
if not (Path(ckpt_dir) / ".metadata").exists():
|
||||||
|
raise ValueError(
|
||||||
|
f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dict = {}
|
||||||
|
if optimizer is not None:
|
||||||
|
state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer)
|
||||||
|
else:
|
||||||
|
state_dict[model_key] = get_model_state_dict(model)
|
||||||
|
if model_key == "": # If only loading a model directly, the key should be empty
|
||||||
|
state_dict = state_dict.pop(model_key)
|
||||||
|
|
||||||
|
dcp.load(state_dict, checkpoint_id=ckpt_dir)
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointManager:
|
||||||
|
def __init__(self, args: CheckpointArgs):
|
||||||
|
self.path = args.path
|
||||||
|
self.dump_every = args.dump
|
||||||
|
self.eval_every = args.eval
|
||||||
|
self.init_ckpt_path = args.init_ckpt_path
|
||||||
|
self.continue_training_from_init = args.continue_training_from_init
|
||||||
|
|
||||||
|
assert os.path.exists(
|
||||||
|
self.path
|
||||||
|
), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
|
||||||
|
|
||||||
|
self.existing_saves = self.get_existing_saves()
|
||||||
|
|
||||||
|
def get_existing_saves(self) -> List[Path]:
|
||||||
|
folders = [
|
||||||
|
p
|
||||||
|
for p in Path(self.path).iterdir()
|
||||||
|
if p.is_dir() and re.match(RE_FOLDER, p.name)
|
||||||
|
]
|
||||||
|
folders.sort(key=lambda p: _get_key_step(p.name))
|
||||||
|
return folders
|
||||||
|
|
||||||
|
def clean_up(self):
|
||||||
|
logger.info("Cleaning up checkpoints...")
|
||||||
|
dump_folders = []
|
||||||
|
eval_folders = []
|
||||||
|
other_folders = []
|
||||||
|
for p in self.existing_saves:
|
||||||
|
is_dump = _get_key_step(p.name) % self.dump_every.every == 0
|
||||||
|
is_eval = _get_key_step(p.name) % self.eval_every.every == 0
|
||||||
|
if is_dump:
|
||||||
|
dump_folders.append(p)
|
||||||
|
if is_eval:
|
||||||
|
eval_folders.append(p)
|
||||||
|
if not (is_dump or is_eval):
|
||||||
|
other_folders.append(p)
|
||||||
|
|
||||||
|
logger.info(f"Dump folders: {dump_folders}")
|
||||||
|
logger.info(f"Eval folders: {eval_folders}")
|
||||||
|
logger.info(f"Other folders: {other_folders}")
|
||||||
|
|
||||||
|
if self.dump_every.keep > 0:
|
||||||
|
dump_folders = dump_folders[-self.dump_every.keep :]
|
||||||
|
if self.eval_every.keep > 0:
|
||||||
|
eval_folders = eval_folders[-self.eval_every.keep :]
|
||||||
|
|
||||||
|
folder_to_keep = set(other_folders + dump_folders + eval_folders)
|
||||||
|
folder_to_remove = set(self.existing_saves) - folder_to_keep
|
||||||
|
|
||||||
|
logger.info(f"Removing folders: {folder_to_remove}")
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
for folder in folder_to_remove:
|
||||||
|
for file in folder.iterdir():
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
elif file.is_dir():
|
||||||
|
assert file.name in [CONSOLIDATE_FOLDER]
|
||||||
|
for f in file.iterdir():
|
||||||
|
f.unlink()
|
||||||
|
file.rmdir()
|
||||||
|
folder.rmdir()
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
self.existing_saves = list(folder_to_keep)
|
||||||
|
self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
|
||||||
|
|
||||||
|
def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
|
||||||
|
path = None
|
||||||
|
for p in reversed(self.existing_saves):
|
||||||
|
if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
|
||||||
|
path = p
|
||||||
|
break
|
||||||
|
return path
|
||||||
|
|
||||||
|
def _create_folder(self, base_path: Path, folder_name: str) -> Path:
|
||||||
|
folder = base_path / folder_name
|
||||||
|
if get_is_master():
|
||||||
|
folder.mkdir(parents=False, exist_ok=True)
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.barrier()
|
||||||
|
return folder
|
||||||
|
|
||||||
|
def _get_dp_tp_mesh(
|
||||||
|
self, device_mesh: Optional[DeviceMesh] = None
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
dp_rank = 0
|
||||||
|
tp_rank = 0
|
||||||
|
if device_mesh is not None:
|
||||||
|
if "dp_replicate" in device_mesh.mesh_dim_names:
|
||||||
|
dp_rank = device_mesh.get_local_rank("dp_replicate")
|
||||||
|
if "dp_shard" in device_mesh.mesh_dim_names:
|
||||||
|
dp_rank = dp_rank * device_mesh[
|
||||||
|
"dp_replicate"
|
||||||
|
].size() + device_mesh.get_local_rank("dp_shard")
|
||||||
|
if "tp" in device_mesh.mesh_dim_names:
|
||||||
|
tp_rank = device_mesh.get_local_rank("tp")
|
||||||
|
return dp_rank, tp_rank
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_state_dict(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
):
|
||||||
|
model_sd, optim_sd = get_state_dict(model, optimizer)
|
||||||
|
return {"model": model_sd, "optim": optim_sd}
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
train_state,
|
||||||
|
config,
|
||||||
|
device_mesh: Optional[DeviceMesh] = None,
|
||||||
|
) -> bool:
|
||||||
|
|
||||||
|
# When creating directory check if only rank0 or is there other solution
|
||||||
|
path = Path(self.path)
|
||||||
|
curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
|
||||||
|
logger.info(f"Saving to: {str(curr_save_dir)}")
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
logger.info("Saving...")
|
||||||
|
state_dict = self.get_state_dict(model, optimizer)
|
||||||
|
dcp.save(state_dict, checkpoint_id=curr_save_dir)
|
||||||
|
logger.info("State dict saved!")
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
if get_is_master():
|
||||||
|
config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
|
||||||
|
|
||||||
|
# Add json dump here
|
||||||
|
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
|
||||||
|
if tp_rank == 0:
|
||||||
|
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
|
||||||
|
logger.info(
|
||||||
|
f"Saving train state to: {str(curr_save_dir / train_state_name)}"
|
||||||
|
)
|
||||||
|
with open(curr_save_dir / train_state_name, "w") as f:
|
||||||
|
json.dump(train_state.state_dict(), f)
|
||||||
|
logger.info("Train state saved !")
|
||||||
|
|
||||||
|
self.existing_saves.append(curr_save_dir)
|
||||||
|
|
||||||
|
self.clean_up()
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.barrier()
|
||||||
|
return True
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer,
|
||||||
|
train_state,
|
||||||
|
device_mesh: DeviceMesh,
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
):
|
||||||
|
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
|
||||||
|
# Loading tries to load the provided path, if not available the last saved step and finally from the init path
|
||||||
|
path = path or self.get_last_step_path(dp_rank=dp_rank)
|
||||||
|
# If none of those are available don't do anything
|
||||||
|
if path is None:
|
||||||
|
# If no checkpoints exist do nothing
|
||||||
|
return
|
||||||
|
|
||||||
|
# Only load train state if it's provided, the files exist and we're not loading from init path
|
||||||
|
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
|
||||||
|
logger.info("Reloading train state")
|
||||||
|
with open(path / train_state_name, "r") as f:
|
||||||
|
train_state_dict = json.load(f)
|
||||||
|
train_state.load_state_dict(train_state_dict)
|
||||||
|
logger.info("Train state reloaded")
|
||||||
|
|
||||||
|
logger.info(f"Loading from: {str(path)}")
|
||||||
|
state_dict = self.get_state_dict(
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
)
|
||||||
|
dcp.load(state_dict, checkpoint_id=path)
|
||||||
|
logger.info("State dict loaded.")
|
||||||
|
|
||||||
|
logger.info("Reloading model and optim")
|
||||||
|
|
||||||
|
set_state_dict(
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
model_state_dict=state_dict["model"],
|
||||||
|
optim_state_dict=state_dict["optim"],
|
||||||
|
)
|
||||||
|
logger.info("Model and optim reloaded")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def instantiate_and_make_dir(cls, args: CheckpointArgs):
|
||||||
|
if get_is_master():
|
||||||
|
os.makedirs(args.path, exist_ok=True)
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
return cls(args)
|
110
bytelatent/configs/debug.yaml
Normal file
110
bytelatent/configs/debug.yaml
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
# Template config, need to change dump_dir, data.root_dir and tokenizer.path
|
||||||
|
# Evals can be activated by uncommenting its config
|
||||||
|
# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
|
||||||
|
|
||||||
|
dump_dir: /tmp/
|
||||||
|
name: "debug"
|
||||||
|
steps: 100_000
|
||||||
|
probe_freq: null
|
||||||
|
seed: 777
|
||||||
|
optim:
|
||||||
|
lr: 4e-04
|
||||||
|
warmup: 500
|
||||||
|
lr_min_ratio: 0.1
|
||||||
|
clip: 10.0
|
||||||
|
|
||||||
|
distributed:
|
||||||
|
fsdp_type: full_shard
|
||||||
|
compile: true
|
||||||
|
model_dtype: bf16
|
||||||
|
matmul_allow_tf32: false
|
||||||
|
selective_activation_checkpointing: false
|
||||||
|
tp_size: 1
|
||||||
|
|
||||||
|
model:
|
||||||
|
n_heads: 8
|
||||||
|
dim: 512
|
||||||
|
vocab_size: 260
|
||||||
|
dim_token: 256
|
||||||
|
patch_size: 6
|
||||||
|
tokenization_mode: "bytes"
|
||||||
|
patching_mode: "space"
|
||||||
|
tie_local_encoder_decoder_logits: false
|
||||||
|
data_loader_patching: true
|
||||||
|
max_encoder_seq_length: 12288
|
||||||
|
pad_to_max_length: true
|
||||||
|
patching_threshold: 3.1439168453216553
|
||||||
|
encoder_hash_byte_group_size: [4]
|
||||||
|
encoder_hash_byte_group_vocab: 50002
|
||||||
|
encoder_hash_byte_group_nb_functions: 3
|
||||||
|
encoder_enable_byte_ngrams: false
|
||||||
|
cross_attn_encoder: true # assuming cross_attention is true
|
||||||
|
cross_attn_decoder: true # assuming cross_attention is true
|
||||||
|
cross_attn_window_encoder: 512
|
||||||
|
cross_attn_window_decoder: 512
|
||||||
|
dim_local_encoder: 256
|
||||||
|
dim_local_decoder: 256
|
||||||
|
cross_attn_k: 8
|
||||||
|
cross_attn_nheads: 4
|
||||||
|
cross_attn_all_layers_decoder: true
|
||||||
|
cross_attn_all_layers_encoder: true
|
||||||
|
cross_attn_use_flex_attention: true
|
||||||
|
cross_attn_init_by_pooling: true
|
||||||
|
log_patch_lengths: true
|
||||||
|
non_linearity: "swiglu"
|
||||||
|
use_rope: true
|
||||||
|
recompute_fc1_out: false
|
||||||
|
recompute_fc3_out: false
|
||||||
|
recompute_attn: false
|
||||||
|
custom_bwd: false
|
||||||
|
layer_ckpt: "none"
|
||||||
|
efficient_attn: "sdpa"
|
||||||
|
patch_only_encoder: false
|
||||||
|
patch_only_decoder: false
|
||||||
|
use_local_encoder_transformer: true
|
||||||
|
init_use_gaussian: true
|
||||||
|
init_use_depth: "current"
|
||||||
|
attn_bias_type: "block_causal"
|
||||||
|
alpha_depth: "disabled"
|
||||||
|
max_length: 256
|
||||||
|
local_attention_window_len: 512
|
||||||
|
max_seqlen: 12288
|
||||||
|
downsampling_by_pooling: "max"
|
||||||
|
|
||||||
|
data:
|
||||||
|
root_dir: ???
|
||||||
|
sources:
|
||||||
|
dclm_baseline_1.0: 1.0
|
||||||
|
batch_size: 2
|
||||||
|
prefetch_size: 64
|
||||||
|
seq_len: 4096
|
||||||
|
load_async: true
|
||||||
|
preprocess_dir: ???
|
||||||
|
tokenizer_args:
|
||||||
|
name: blt
|
||||||
|
init_kwargs:
|
||||||
|
bpe_tokenizer_path: ???
|
||||||
|
|
||||||
|
profiling:
|
||||||
|
run: false
|
||||||
|
|
||||||
|
checkpoint:
|
||||||
|
dump:
|
||||||
|
every: 500
|
||||||
|
keep: 3
|
||||||
|
eval:
|
||||||
|
every: 1000
|
||||||
|
keep: -1
|
||||||
|
|
||||||
|
logging:
|
||||||
|
freq: 10
|
||||||
|
|
||||||
|
eval_on_gpus: 8
|
||||||
|
eval:
|
||||||
|
dataset_dir: /checkpoint/amaia/codegen/datasets/eval
|
||||||
|
tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
|
||||||
|
generator:
|
||||||
|
max_tokens: 65536
|
||||||
|
dtype: bf16
|
||||||
|
|
||||||
|
mp_size: 1
|
5
bytelatent/constants.py
Normal file
5
bytelatent/constants.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
BLT_DATA = Path(os.environ.get("BLT_DATA", "data"))
|
1
bytelatent/data/__init__.py
Normal file
1
bytelatent/data/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
115
bytelatent/data/data_types.py
Normal file
115
bytelatent/data/data_types.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Iterator
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class BltExample(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
sample_id: str
|
||||||
|
text: str
|
||||||
|
tokens: list[int] | None
|
||||||
|
entropies: list[float] | None
|
||||||
|
patch_lengths: list[int] | None
|
||||||
|
mask: list[bool] | None
|
||||||
|
|
||||||
|
|
||||||
|
class MultiChoiceState(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
root_dir: str
|
||||||
|
sources: dict[str, float]
|
||||||
|
source_to_state: dict[str, Any]
|
||||||
|
rng_state: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class PrefetchState(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
seq_idx: int
|
||||||
|
rng_state: dict[str, Any]
|
||||||
|
prefetch_size: int
|
||||||
|
batch_size: int
|
||||||
|
|
||||||
|
|
||||||
|
class BltPackTokensState(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
start_token: int
|
||||||
|
output_seq_len: int
|
||||||
|
n_views: int = 2
|
||||||
|
|
||||||
|
|
||||||
|
class DataLoaderState(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
multi_choice_state: MultiChoiceState
|
||||||
|
pack_tokens_state: BltPackTokensState
|
||||||
|
prefetch_state: PrefetchState
|
||||||
|
|
||||||
|
|
||||||
|
BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
|
||||||
|
|
||||||
|
|
||||||
|
class BltSequence(BaseModel):
|
||||||
|
tokens: list[int]
|
||||||
|
mask: list[bool]
|
||||||
|
patch_lengths: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Batch:
|
||||||
|
x: np.ndarray
|
||||||
|
y: np.ndarray
|
||||||
|
mask: np.ndarray | None = None
|
||||||
|
patch_lengths: np.ndarray | None = None
|
||||||
|
ngram_ids: np.ndarray | None = None
|
||||||
|
is_final: bool = False
|
||||||
|
|
||||||
|
def to_python_dict(self) -> dict:
|
||||||
|
x = self.x.tolist()
|
||||||
|
y = self.y.tolist()
|
||||||
|
if self.mask is None:
|
||||||
|
mask = None
|
||||||
|
else:
|
||||||
|
mask = self.mask.tolist()
|
||||||
|
if self.patch_lengths is None:
|
||||||
|
patch_lengths = None
|
||||||
|
else:
|
||||||
|
patch_lengths = self.patch_lengths.tolist()
|
||||||
|
if self.ngram_ids is None:
|
||||||
|
ngram_ids = None
|
||||||
|
else:
|
||||||
|
ngram_ids = self.ngram_ids.tolist()
|
||||||
|
return {
|
||||||
|
"x": x,
|
||||||
|
"y": y,
|
||||||
|
"mask": mask,
|
||||||
|
"patch_lengths": patch_lengths,
|
||||||
|
"ngram_ids": ngram_ids,
|
||||||
|
"is_final": self.is_final,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_python_dict(cls, data: dict) -> "Batch":
|
||||||
|
x = np.array(data["x"])
|
||||||
|
y = np.array(data["y"])
|
||||||
|
if data["mask"] is None:
|
||||||
|
mask = None
|
||||||
|
else:
|
||||||
|
mask = np.array(data["mask"])
|
||||||
|
if data["patch_lengths"] is None:
|
||||||
|
patch_lengths = None
|
||||||
|
else:
|
||||||
|
patch_lengths = np.array(data["patch_lengths"])
|
||||||
|
if data["ngram_ids"] is None:
|
||||||
|
ngram_ids = None
|
||||||
|
else:
|
||||||
|
ngram_ids = np.array(data["ngram_ids"])
|
||||||
|
return Batch(
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
mask=mask,
|
||||||
|
patch_lengths=patch_lengths,
|
||||||
|
ngram_ids=ngram_ids,
|
||||||
|
is_final=data["is_final"],
|
||||||
|
)
|
1
bytelatent/data/iterators/__init__.py
Normal file
1
bytelatent/data/iterators/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
23
bytelatent/data/iterators/abstract_iterator.py
Normal file
23
bytelatent/data/iterators/abstract_iterator.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import abc
|
||||||
|
from typing import Any, Generator, Generic, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
C = TypeVar("C")
|
||||||
|
|
||||||
|
|
||||||
|
class StatefulIterator(Generic[T, C], abc.ABC):
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_state(self) -> C:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def create_iter(self) -> Generator[T, Any, None]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class IteratorState(Generic[C]):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def build(self) -> StatefulIterator[T, C]:
|
||||||
|
pass
|
216
bytelatent/data/iterators/arrow_iterator.py
Normal file
216
bytelatent/data/iterators/arrow_iterator.py
Normal file
|
@ -0,0 +1,216 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import re
|
||||||
|
from logging import getLogger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Generator
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
# pyarrow needs the initialization from this import
|
||||||
|
import pyarrow.dataset # pyright: ignore
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from bytelatent import ByteLatentError
|
||||||
|
from bytelatent.data.data_types import BltExample
|
||||||
|
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ArrowFileIteratorState(BaseModel, IteratorState):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
file_path: str | None
|
||||||
|
row_num: int
|
||||||
|
num_workers: int
|
||||||
|
worker_id: int
|
||||||
|
preprocess_dir: str | None
|
||||||
|
dataset_files: list[str] | None
|
||||||
|
entropy_model_name: str | None
|
||||||
|
arrow_batch_size: int = 100
|
||||||
|
|
||||||
|
def build(self) -> "ArrowFileIterator":
|
||||||
|
arrow_file = ArrowFileIterator(
|
||||||
|
file_path=self.file_path,
|
||||||
|
worker_id=self.worker_id,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
preprocess_dir=self.preprocess_dir,
|
||||||
|
entropy_model_name=self.entropy_model_name,
|
||||||
|
arrow_batch_size=self.arrow_batch_size,
|
||||||
|
dataset_files=self.dataset_files,
|
||||||
|
)
|
||||||
|
if self.row_num != 0:
|
||||||
|
arrow_file._set_row_num(self.row_num)
|
||||||
|
return arrow_file
|
||||||
|
|
||||||
|
|
||||||
|
def shard_sort_key(file: str | Path):
|
||||||
|
match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
|
||||||
|
shard_number = int(match.group(1))
|
||||||
|
return shard_number
|
||||||
|
|
||||||
|
|
||||||
|
class ArrowFileIterator(StatefulIterator):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_path: str | None,
|
||||||
|
worker_id: int,
|
||||||
|
num_workers: int,
|
||||||
|
preprocess_dir: str | None,
|
||||||
|
entropy_model_name: str | None,
|
||||||
|
arrow_batch_size: int,
|
||||||
|
dataset_files: list[str] | None = None,
|
||||||
|
):
|
||||||
|
assert 0 <= worker_id < num_workers, (worker_id, num_workers)
|
||||||
|
if file_path is None and dataset_files is None:
|
||||||
|
raise ByteLatentError("file_path and dataset_files cannot both be None")
|
||||||
|
self.row_num = 0
|
||||||
|
self.iter_id = 0
|
||||||
|
self.batch_iterator = None
|
||||||
|
self.batch_to_consume = None
|
||||||
|
self.dataset = None
|
||||||
|
self.file_path = file_path
|
||||||
|
self.worker_id = worker_id
|
||||||
|
self.num_workers = num_workers
|
||||||
|
self.preprocess_dir = preprocess_dir
|
||||||
|
self.entropy_model_name = entropy_model_name
|
||||||
|
self.arrow_batch_size = arrow_batch_size
|
||||||
|
if dataset_files is None:
|
||||||
|
# Prepare arrow shards
|
||||||
|
jsonl_file = Path(file_path)
|
||||||
|
parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
|
||||||
|
assert parts is not None
|
||||||
|
dataset = parts.group(1)
|
||||||
|
data_dir = Path(preprocess_dir) / dataset / entropy_model_name
|
||||||
|
shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
|
||||||
|
for s in shard_files:
|
||||||
|
if not (data_dir / f"{s.name}.complete").exists():
|
||||||
|
raise ValueError(f"Missing .complete for input file: {s}")
|
||||||
|
|
||||||
|
shard_files = sorted(shard_files, key=shard_sort_key)
|
||||||
|
if len(shard_files) == 0:
|
||||||
|
raise ByteLatentError(
|
||||||
|
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
|
||||||
|
)
|
||||||
|
self.dataset_files = [str(f) for f in shard_files]
|
||||||
|
else:
|
||||||
|
self.preprocess_dir = None
|
||||||
|
self.dataset_files = dataset_files
|
||||||
|
|
||||||
|
def get_state(self) -> ArrowFileIteratorState:
|
||||||
|
return ArrowFileIteratorState(
|
||||||
|
file_path=self.file_path,
|
||||||
|
row_num=self.row_num,
|
||||||
|
worker_id=self.worker_id,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
preprocess_dir=self.preprocess_dir,
|
||||||
|
entropy_model_name=self.entropy_model_name,
|
||||||
|
arrow_batch_size=self.arrow_batch_size,
|
||||||
|
dataset_files=self.dataset_files,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_iter(
|
||||||
|
self,
|
||||||
|
) -> Generator[BltExample, Any, None]:
|
||||||
|
if self.dataset is None:
|
||||||
|
self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
|
||||||
|
self.batch_iterator = self.dataset.to_batches(
|
||||||
|
batch_size=self.arrow_batch_size
|
||||||
|
)
|
||||||
|
self.iter_id += 1
|
||||||
|
if self.batch_to_consume is not None:
|
||||||
|
batch_columns: dict[str, list] = self.batch_to_consume
|
||||||
|
self.batch_to_consume = None
|
||||||
|
sample_ids = batch_columns["sample_id"]
|
||||||
|
texts = batch_columns["text"]
|
||||||
|
entropies = batch_columns["entropies"]
|
||||||
|
for i in range(len(sample_ids)):
|
||||||
|
out = BltExample(
|
||||||
|
sample_id=sample_ids[i],
|
||||||
|
entropies=entropies[i],
|
||||||
|
text=texts[i],
|
||||||
|
tokens=None,
|
||||||
|
mask=None,
|
||||||
|
patch_lengths=None,
|
||||||
|
)
|
||||||
|
self.row_num += 1
|
||||||
|
if (self.row_num - 1) % self.num_workers == self.worker_id:
|
||||||
|
yield out
|
||||||
|
|
||||||
|
for batch in self.batch_iterator:
|
||||||
|
batch_columns = batch.to_pydict()
|
||||||
|
sample_ids = batch_columns["sample_id"]
|
||||||
|
texts = batch_columns["text"]
|
||||||
|
entropies = batch_columns["entropies"]
|
||||||
|
for i in range(len(sample_ids)):
|
||||||
|
out = BltExample(
|
||||||
|
sample_id=sample_ids[i],
|
||||||
|
entropies=entropies[i],
|
||||||
|
text=texts[i],
|
||||||
|
tokens=None,
|
||||||
|
mask=None,
|
||||||
|
patch_lengths=None,
|
||||||
|
)
|
||||||
|
self.row_num += 1
|
||||||
|
if (self.row_num - 1) % self.num_workers == self.worker_id:
|
||||||
|
yield out
|
||||||
|
|
||||||
|
def _set_row_num(self, target_row_num: int):
|
||||||
|
logger.info(
|
||||||
|
f"Setting arrow position to {target_row_num} for {self.dataset_files}"
|
||||||
|
)
|
||||||
|
if target_row_num is None or target_row_num == 0:
|
||||||
|
self.row_num = 0
|
||||||
|
self.dataset = None
|
||||||
|
self.batch_iterator = None
|
||||||
|
self.batch_to_consume = None
|
||||||
|
else:
|
||||||
|
self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
|
||||||
|
self.batch_iterator = self.dataset.to_batches(
|
||||||
|
batch_size=self.arrow_batch_size
|
||||||
|
)
|
||||||
|
curr_remaining = target_row_num
|
||||||
|
for batch in self.batch_iterator:
|
||||||
|
if len(batch) > curr_remaining:
|
||||||
|
batch_columns: dict[str, list] = batch.to_pydict()
|
||||||
|
batch_columns["sample_id"] = batch_columns["sample_id"][
|
||||||
|
curr_remaining:
|
||||||
|
]
|
||||||
|
batch_columns["entropies"] = batch_columns["entropies"][
|
||||||
|
curr_remaining:
|
||||||
|
]
|
||||||
|
batch_columns["text"] = batch_columns["text"][curr_remaining:]
|
||||||
|
self.batch_to_consume = batch_columns
|
||||||
|
break
|
||||||
|
elif len(batch) == curr_remaining:
|
||||||
|
# We are exactly at the end of the batch,
|
||||||
|
# so the next batch is the right spot
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
curr_remaining -= len(batch)
|
||||||
|
self.row_num = target_row_num
|
||||||
|
logger.info(
|
||||||
|
f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
def find_and_sanitize_chunks(
|
||||||
|
dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
|
||||||
|
):
|
||||||
|
dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
|
||||||
|
n_chunks = len(dataset_chunks)
|
||||||
|
|
||||||
|
if n_chunks > world_size:
|
||||||
|
n_discard = n_chunks - world_size
|
||||||
|
dataset_chunks = dataset_chunks[:world_size]
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
world_size % n_chunks == 0
|
||||||
|
), "World size should be a multiple of number of chunks"
|
||||||
|
|
||||||
|
assert n_chunks > 0, f"No valid chunks in {dataset_path}"
|
||||||
|
|
||||||
|
return dataset_chunks
|
36
bytelatent/data/iterators/looping_iterator.py
Normal file
36
bytelatent/data/iterators/looping_iterator.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
||||||
|
from bytelatent.data.iterators.arrow_iterator import (
|
||||||
|
ArrowFileIterator,
|
||||||
|
ArrowFileIteratorState,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LoopingIteratorState(BaseModel, IteratorState):
|
||||||
|
file_iterator_state: ArrowFileIteratorState
|
||||||
|
epoch: int
|
||||||
|
|
||||||
|
def build(self) -> "LoopingIterator":
|
||||||
|
return LoopingIterator(
|
||||||
|
file_iterator=self.file_iterator_state.build(),
|
||||||
|
epoch=self.epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LoopingIterator(StatefulIterator):
|
||||||
|
def __init__(self, file_iterator: ArrowFileIterator, epoch: int = -1):
|
||||||
|
self.file_iterator = file_iterator
|
||||||
|
self.epoch = epoch
|
||||||
|
|
||||||
|
def get_state(self):
|
||||||
|
return LoopingIteratorState(
|
||||||
|
file_iterator_state=self.file_iterator.get_state(), epoch=self.epoch
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_iter(self):
|
||||||
|
while True:
|
||||||
|
self.epoch += 1
|
||||||
|
iterator = self.file_iterator.create_iter()
|
||||||
|
yield from iterator
|
243
bytelatent/data/iterators/multiprocess_iterator.py
Normal file
243
bytelatent/data/iterators/multiprocess_iterator.py
Normal file
|
@ -0,0 +1,243 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
from multiprocessing.synchronize import Event as EventClass
|
||||||
|
from queue import Empty, Full
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from bytelatent.data.data_types import Batch
|
||||||
|
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
||||||
|
from bytelatent.data.iterators.packing_iterator import PackingIteratorState
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class MultiprocessIteratorState(BaseModel, IteratorState):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
base_iterator_state: PackingIteratorState
|
||||||
|
n_batches_to_prefetch: int
|
||||||
|
serialized_prefetch_buffer: str
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
base_iterator = self.base_iterator_state.build()
|
||||||
|
data = json.loads(self.serialized_prefetch_buffer)
|
||||||
|
prefetch_buffer = [Batch.from_python_dict(item) for item in data]
|
||||||
|
return MultiprocessIterator(
|
||||||
|
base_iterator,
|
||||||
|
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
||||||
|
prefetch_buffer=prefetch_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def start_work_from_state(
|
||||||
|
batch_queue: mp.Queue,
|
||||||
|
state_queue: mp.Queue,
|
||||||
|
stop_event: EventClass,
|
||||||
|
state_dumped_event: EventClass,
|
||||||
|
state: IteratorState,
|
||||||
|
):
|
||||||
|
logging.info("Worker thread: Starting base_iterator work")
|
||||||
|
stateful_iterator = state.build()
|
||||||
|
iterator = stateful_iterator.create_iter()
|
||||||
|
for item in iterator:
|
||||||
|
while not stop_event.is_set():
|
||||||
|
try:
|
||||||
|
# Attempt to put on queue or timeout to try again (maybe main thread is busy)
|
||||||
|
batch_queue.put(item, timeout=0.1)
|
||||||
|
# On success, stop trying
|
||||||
|
break
|
||||||
|
except Full:
|
||||||
|
pass
|
||||||
|
if stop_event.is_set():
|
||||||
|
# Signal the end of output, this ensures that even if the queue takes a while to
|
||||||
|
# buffer, that the main thread receives everything (and tosses this fake batch)
|
||||||
|
logging.info(
|
||||||
|
"Worker thread: Stop event detected, outputting is_final=True batch"
|
||||||
|
)
|
||||||
|
batch_queue.put(
|
||||||
|
Batch(
|
||||||
|
x=np.zeros((1, 1)),
|
||||||
|
y=np.zeros((1, 1)),
|
||||||
|
is_final=True,
|
||||||
|
mask=None,
|
||||||
|
patch_lengths=None,
|
||||||
|
ngram_ids=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
logging.info("Worker thread: outputting state")
|
||||||
|
state_queue.put(iterator.get_state(), timeout=1)
|
||||||
|
logging.info("Worker thread: state dump complete")
|
||||||
|
state_dumped_event.set()
|
||||||
|
logging.info("Worker thread: set state_dump_event")
|
||||||
|
except Full:
|
||||||
|
raise ValueError(
|
||||||
|
"Attempted to dump state into the state queue, but it was full"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiprocessIterator(StatefulIterator):
|
||||||
|
"""
|
||||||
|
Design sketch of the multiprocess iterator:
|
||||||
|
|
||||||
|
Given the base_iterator, the only thing we do with this is call get_state()
|
||||||
|
so that we can pass that through to the background worker process.
|
||||||
|
|
||||||
|
The background process will receive this, rebuild the iterator, then start yielding from it.
|
||||||
|
|
||||||
|
However, in order to implement MultiprocessIterator.get_state(), we need to be able to accurately get
|
||||||
|
(1) the state of the iterator in the worker process
|
||||||
|
(2) the currently buffered items in the Queue
|
||||||
|
|
||||||
|
To do this, we use:
|
||||||
|
- batch_queue: This is the prefetch buffer the worker yields to and the main loop yields from
|
||||||
|
- state_queue: This size 1 queue will be how the worker sends the iterator state once it has halted iterating.
|
||||||
|
It must hold the state in addition to the last batch, if the queue was full at the time the stop event is sent.
|
||||||
|
- stop_iterating_event: Once this is issued from the main loop, the worker will stop iterating and enter cleanup.
|
||||||
|
During cleanup, the iterator will send the state of the current iterator to the main loop,
|
||||||
|
in addition to possibly the last batch if the batch_queue was full at the time
|
||||||
|
- state_dumped_event: When the main loop issues the stop_iterating_event, it will wait until the state_dumped_event to attempt
|
||||||
|
to get state from the state_queue. It must do this since the worker may take some time to create and send the state.
|
||||||
|
Once received by the main loop, the main loop can safely store the Queue (plus maybe the last batch) as the prefetch buffer,
|
||||||
|
get the worker iterator's state, and terminate the background process + delete associated objects.
|
||||||
|
|
||||||
|
At this point, calling create_iter() again will bootstrap everything from the stored state and the old iterator will throw an error
|
||||||
|
since it will not iterate anymore (so the caller must call create_iter() again to get a python iterator).
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_iterator: StatefulIterator,
|
||||||
|
*,
|
||||||
|
n_batches_to_prefetch: int,
|
||||||
|
prefetch_buffer: list | None = None
|
||||||
|
):
|
||||||
|
self.base_iterator = base_iterator
|
||||||
|
self.n_batches_to_prefetch = n_batches_to_prefetch
|
||||||
|
if prefetch_buffer is None:
|
||||||
|
prefetch_buffer = []
|
||||||
|
self.prefetch_buffer = prefetch_buffer
|
||||||
|
self.batch_queue = None
|
||||||
|
self.state_queue = None
|
||||||
|
self.producer = None
|
||||||
|
self.stop_iterating_event = None
|
||||||
|
self.state_dumped_event = None
|
||||||
|
|
||||||
|
def get_state(self) -> MultiprocessIteratorState:
|
||||||
|
"""
|
||||||
|
This is slightly unusual in effectively destroying the current iterator, its necessary
|
||||||
|
to halt the background process and allow it to write the state to the main loop
|
||||||
|
in order to not lose data
|
||||||
|
"""
|
||||||
|
if self.producer is None:
|
||||||
|
serialized_prefetch_buffer = json.dumps(
|
||||||
|
[b.to_python_dict() for b in self.prefetch_buffer]
|
||||||
|
)
|
||||||
|
return MultiprocessIteratorState(
|
||||||
|
base_iterator_state=self.base_iterator.get_state(),
|
||||||
|
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
||||||
|
serialized_prefetch_buffer=serialized_prefetch_buffer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Main thread: Sending stop iteration event")
|
||||||
|
self.stop_iterating_event.set()
|
||||||
|
logging.info("Main thread: Waiting for state_dumped event")
|
||||||
|
self.state_dumped_event.wait()
|
||||||
|
self.prefetch_buffer = []
|
||||||
|
final_batch_received = False
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
batch = self.batch_queue.get(timeout=1)
|
||||||
|
if batch.is_final:
|
||||||
|
final_batch_received = True
|
||||||
|
break
|
||||||
|
self.prefetch_buffer.append(batch)
|
||||||
|
except Empty:
|
||||||
|
logging.warning("Main thread: batch_queue is abnormally empty")
|
||||||
|
assert final_batch_received
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_iterator_state = self.state_queue.get(timeout=1)
|
||||||
|
assert isinstance(base_iterator_state, IteratorState)
|
||||||
|
except Empty:
|
||||||
|
raise ValueError(
|
||||||
|
"Attempted to get the state, but it was unexpectantly missing"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.base_iterator = base_iterator_state.build()
|
||||||
|
self.producer.close()
|
||||||
|
self.producer = None
|
||||||
|
self.batch_queue = None
|
||||||
|
self.state_queue = None
|
||||||
|
self.stop_iterating_event = None
|
||||||
|
self.state_dumped_event = None
|
||||||
|
|
||||||
|
return MultiprocessIteratorState(
|
||||||
|
base_iterator_state=self.base_iterator.get_state(),
|
||||||
|
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
||||||
|
serialized_prefetch_buffer=json.dumps(
|
||||||
|
[b.to_python_dict() for b in self.prefetch_buffer]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_iter(self):
|
||||||
|
logging.info("Main thread: Creating MP iterator")
|
||||||
|
# First yield from the stored prefetch buffer.
|
||||||
|
if self.prefetch_buffer is not None:
|
||||||
|
while len(self.prefetch_buffer) > 0:
|
||||||
|
item = self.prefetch_buffer.pop(0)
|
||||||
|
yield item
|
||||||
|
self.prefetch_buffer = None
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.producer is None
|
||||||
|
), "Cannot create two parallel iterators at once, call get_state() then remake to have two."
|
||||||
|
|
||||||
|
# using mp context manager avoids excessive CPU loading
|
||||||
|
ctx = mp.get_context("forkserver")
|
||||||
|
self.batch_queue = ctx.Manager().Queue(maxsize=self.n_batches_to_prefetch)
|
||||||
|
|
||||||
|
# We should only ever one state, which is output at the detection of a stop event
|
||||||
|
self.state_queue = ctx.Manager().Queue(maxsize=1)
|
||||||
|
|
||||||
|
self.stop_iterating_event = ctx.Event()
|
||||||
|
self.state_dumped_event = ctx.Event()
|
||||||
|
|
||||||
|
self.producer = mp.Process(
|
||||||
|
name="blt_data_loader",
|
||||||
|
target=start_work_from_state,
|
||||||
|
args=(
|
||||||
|
self.batch_queue,
|
||||||
|
self.state_queue,
|
||||||
|
self.stop_iterating_event,
|
||||||
|
self.state_dumped_event,
|
||||||
|
self.base_iterator.get_state(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
logger.info("Async dataloader started")
|
||||||
|
self.producer.start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if self.producer.exitcode is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Data loader quit unexpectedly, real error has been raised previously"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
batch = self.batch_queue.get(timeout=0.1)
|
||||||
|
assert isinstance(batch, Batch)
|
||||||
|
assert (
|
||||||
|
not batch.is_final
|
||||||
|
), "is_final should only be used during get_state() being called"
|
||||||
|
yield batch
|
||||||
|
except Empty:
|
||||||
|
pass
|
||||||
|
if self.producer is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Attempted to call this iterator after calling get_state(). You must call create_iter() to make a new iterator instead."
|
||||||
|
)
|
226
bytelatent/data/iterators/packing_iterator.py
Normal file
226
bytelatent/data/iterators/packing_iterator.py
Normal file
|
@ -0,0 +1,226 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from bytelatent.data.data_types import Batch, BltSequence
|
||||||
|
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
||||||
|
from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
|
||||||
|
|
||||||
|
|
||||||
|
class PackingArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
batch_size: int
|
||||||
|
seq_len: int
|
||||||
|
pad_id: int
|
||||||
|
max_length: int | None
|
||||||
|
pad_to_max_length: bool
|
||||||
|
enable_byte_ngrams: bool
|
||||||
|
|
||||||
|
|
||||||
|
class PackingIteratorState(BaseModel, IteratorState):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
sequence_iterator_state: SamplingIteratorState
|
||||||
|
packing_args: PackingArgs
|
||||||
|
|
||||||
|
def build(self) -> "PackingIterator":
|
||||||
|
return PackingIterator(
|
||||||
|
sequence_iterator=self.sequence_iterator_state.build(),
|
||||||
|
packing_args=self.packing_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_patch_seq_masks(bs, slen: int, mask_seqs: list[list[bool]]):
|
||||||
|
assert len(mask_seqs) == bs
|
||||||
|
lens = [len(m) for m in mask_seqs]
|
||||||
|
if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens):
|
||||||
|
return None
|
||||||
|
assert slen == max(lens) - 1
|
||||||
|
mask = np.zeros((bs, slen), dtype=bool)
|
||||||
|
for i, m in enumerate(mask_seqs):
|
||||||
|
if m is None:
|
||||||
|
print(
|
||||||
|
"Did not implement None mask, the mask should be True for all toks, so we need to pass that to this function."
|
||||||
|
)
|
||||||
|
raise NotImplementedError
|
||||||
|
mask[i][: len(mask_seqs[i]) - 1] = mask_seqs[i][1:]
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_batch(
|
||||||
|
batch: Batch,
|
||||||
|
max_length: int,
|
||||||
|
pad_id: int,
|
||||||
|
pad_to_max_length: bool = False,
|
||||||
|
*,
|
||||||
|
enable_byte_ngrams: bool,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Truncate the x to a given size, making sure we remove the corresponding patch sizes in patch_lenghts
|
||||||
|
and fixing the batch.mask.
|
||||||
|
|
||||||
|
batch.patch_lengths has unchanged shape
|
||||||
|
x,y, and mask may reduce in size
|
||||||
|
"""
|
||||||
|
if batch.patch_lengths is None:
|
||||||
|
return batch
|
||||||
|
|
||||||
|
seq_lengths = batch.patch_lengths.sum(axis=1)
|
||||||
|
max_length_adj = max_length + 1
|
||||||
|
if np.any(seq_lengths > max_length_adj):
|
||||||
|
for i in range(batch.x.shape[0]):
|
||||||
|
if seq_lengths[i] > max_length_adj:
|
||||||
|
# Find id of patch that tips over max_length + 1
|
||||||
|
count, j = 0, 0
|
||||||
|
while count + batch.patch_lengths[i, j] <= max_length_adj:
|
||||||
|
count += batch.patch_lengths[i, j]
|
||||||
|
j += 1
|
||||||
|
# Edit the batch
|
||||||
|
assert j < batch.patch_lengths.shape[1]
|
||||||
|
batch.x[i, max_length:] = pad_id
|
||||||
|
batch.y[i, max_length:] = pad_id
|
||||||
|
if batch.mask is not None:
|
||||||
|
batch.mask[i, max_length:] = False
|
||||||
|
batch.patch_lengths[i, j:] = 0
|
||||||
|
batch.patch_lengths[i, j] = max_length_adj - count
|
||||||
|
|
||||||
|
# Truncate if necessary.
|
||||||
|
if max_length < batch.x.shape[1]:
|
||||||
|
batch.x = batch.x[:, :max_length]
|
||||||
|
batch.y = batch.y[:, :max_length]
|
||||||
|
if batch.mask is not None:
|
||||||
|
batch.mask = batch.mask[:, :max_length]
|
||||||
|
|
||||||
|
# Right pad to max_length if necessary
|
||||||
|
elif pad_to_max_length:
|
||||||
|
if batch.x.shape[1] < max_length:
|
||||||
|
# NOTE: this has to be done on an actual patch.
|
||||||
|
non_zero_indices = (batch.patch_lengths != 0).sum(axis=1) - 1
|
||||||
|
non_zero_indices = np.maximum(0, non_zero_indices)
|
||||||
|
batch.patch_lengths[range(len(batch.patch_lengths)), non_zero_indices] += (
|
||||||
|
max_length - batch.x.shape[1]
|
||||||
|
)
|
||||||
|
# TODO: We could get rid of many of these complications by moving this funciton directly in the dataloader.
|
||||||
|
x = np.full((batch.x.shape[0], max_length), pad_id, dtype=batch.x.dtype)
|
||||||
|
x[:, : batch.x.shape[1]] = batch.x
|
||||||
|
batch.x = x
|
||||||
|
if batch.y.shape[1] < max_length:
|
||||||
|
y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype)
|
||||||
|
y[:, : batch.y.shape[1]] = batch.y
|
||||||
|
batch.y = y
|
||||||
|
if batch.mask is not None and batch.mask.shape[1] < max_length:
|
||||||
|
mask = np.full(
|
||||||
|
(batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
|
||||||
|
)
|
||||||
|
mask[:, : batch.mask.shape[1]] = batch.mask
|
||||||
|
batch.mask = mask
|
||||||
|
|
||||||
|
assert batch.x.shape[1] <= max_length
|
||||||
|
assert batch.y.shape[1] <= max_length
|
||||||
|
assert batch.mask is None or batch.mask.shape[1] <= max_length
|
||||||
|
assert np.all(max_length_adj - batch.patch_lengths.sum(axis=1) == 0)
|
||||||
|
if pad_to_max_length:
|
||||||
|
assert batch.x.shape[1] == max_length
|
||||||
|
assert batch.y.shape[1] == max_length
|
||||||
|
assert batch.mask is None or batch.mask.shape[1] == max_length
|
||||||
|
if enable_byte_ngrams:
|
||||||
|
raise NotImplementedError()
|
||||||
|
# (num_ngram, batch_size, seq_len)
|
||||||
|
ngram_ids = np.array(tokenizer.encode_token_ngrams(batch.x))
|
||||||
|
assert ngram_ids.shape[2] == batch.x.shape[1]
|
||||||
|
else:
|
||||||
|
ngram_ids = None
|
||||||
|
batch.ngram_ids = ngram_ids
|
||||||
|
|
||||||
|
|
||||||
|
class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sequence_iterator: StatefulIterator[BltSequence, Any],
|
||||||
|
*,
|
||||||
|
packing_args: PackingArgs,
|
||||||
|
):
|
||||||
|
self.sequence_iterator = sequence_iterator
|
||||||
|
self.packing_args = packing_args
|
||||||
|
|
||||||
|
def get_state(self):
|
||||||
|
return PackingIteratorState(
|
||||||
|
sequence_iterator_state=self.sequence_iterator.get_state(),
|
||||||
|
packing_args=self.packing_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_iter(self):
|
||||||
|
sequence_iter = self.sequence_iterator.create_iter()
|
||||||
|
batch_size = self.packing_args.batch_size
|
||||||
|
pad_id = self.packing_args.pad_id
|
||||||
|
seq_len = self.packing_args.seq_len
|
||||||
|
pad_to_max_length = self.packing_args.pad_to_max_length
|
||||||
|
enable_byte_ngrams = self.packing_args.enable_byte_ngrams
|
||||||
|
max_length = self.packing_args.max_length
|
||||||
|
while True:
|
||||||
|
tokens: list[list[int]] = []
|
||||||
|
masks: list[list[bool]] = []
|
||||||
|
patch_lengths: list[list[int]] = []
|
||||||
|
|
||||||
|
for _ in range(self.packing_args.batch_size):
|
||||||
|
sequence = next(sequence_iter)
|
||||||
|
_tokens = sequence.tokens
|
||||||
|
_mask = sequence.mask
|
||||||
|
_patch_lengths = sequence.patch_lengths
|
||||||
|
assert len(sequence.patch_lengths) == self.packing_args.seq_len
|
||||||
|
last_patch_length = 0
|
||||||
|
if _patch_lengths[0] > 1:
|
||||||
|
last_patch_length = _patch_lengths[-1]
|
||||||
|
_patch_lengths[0] -= 1
|
||||||
|
_patch_lengths = [1] + _patch_lengths[:-1]
|
||||||
|
tokens.append(_tokens[: len(_tokens) - last_patch_length])
|
||||||
|
masks.append(_mask[: len(_mask) - last_patch_length])
|
||||||
|
patch_lengths.append(_patch_lengths)
|
||||||
|
|
||||||
|
x_patch_lengths = np.array(patch_lengths)
|
||||||
|
# pad batch to same length
|
||||||
|
tok_seq_len = max([len(toks) for toks in tokens]) - 1
|
||||||
|
x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
|
||||||
|
y = np.full((batch_size, tok_seq_len), fill_value=pad_id)
|
||||||
|
|
||||||
|
for i, tok_seq in enumerate(tokens):
|
||||||
|
x[i, : len(tok_seq) - 1] = tok_seq[:-1]
|
||||||
|
y[i, : len(tok_seq) - 1] = tok_seq[1:]
|
||||||
|
# Adjust patch lengths to match x
|
||||||
|
x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
|
||||||
|
|
||||||
|
assert x_patch_lengths.shape == (batch_size, seq_len)
|
||||||
|
|
||||||
|
if enable_byte_ngrams:
|
||||||
|
raise NotImplementedError()
|
||||||
|
else:
|
||||||
|
ngram_ids = None
|
||||||
|
|
||||||
|
batch = Batch(
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
patch_lengths=x_patch_lengths,
|
||||||
|
ngram_ids=ngram_ids,
|
||||||
|
mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
x_patch_lengths.sum() == x.size + batch_size
|
||||||
|
), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
|
||||||
|
assert (
|
||||||
|
batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
|
||||||
|
), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
|
||||||
|
assert np.all(
|
||||||
|
x_patch_lengths[:, 0] == 1
|
||||||
|
), f"first patch should always be 1, {x_patch_lengths[:, 0]}"
|
||||||
|
# cuda_gb_allocated = (torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024)
|
||||||
|
# cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024
|
||||||
|
# print(f"dataloader cuda_gb_allocated: {cuda_gb_allocated}, cuda_gb_reserved: {cuda_gb_reserved}")
|
||||||
|
truncate_batch(
|
||||||
|
batch,
|
||||||
|
max_length=max_length,
|
||||||
|
pad_id=pad_id,
|
||||||
|
pad_to_max_length=pad_to_max_length,
|
||||||
|
enable_byte_ngrams=enable_byte_ngrams,
|
||||||
|
)
|
||||||
|
yield batch
|
111
bytelatent/data/iterators/preprocess_iterator.py
Normal file
111
bytelatent/data/iterators/preprocess_iterator.py
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
from typing import Any, Generator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from bytelatent.data.data_types import BltExample
|
||||||
|
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
||||||
|
from bytelatent.data.iterators.arrow_iterator import (
|
||||||
|
ArrowFileIterator,
|
||||||
|
ArrowFileIteratorState,
|
||||||
|
)
|
||||||
|
from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
|
||||||
|
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
|
||||||
|
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
|
||||||
|
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessIteratorState(BaseModel, IteratorState):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
|
||||||
|
add_tokens: bool
|
||||||
|
add_patches: bool
|
||||||
|
tokenizer_args: TokenizerArgs
|
||||||
|
patcher_args: PatcherArgs
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
arrow_iterator = self.arrow_file_iterator_state.build()
|
||||||
|
return PreprocessIterator(
|
||||||
|
arrow_iterator,
|
||||||
|
patcher_args=self.patcher_args,
|
||||||
|
tokenizer_args=self.tokenizer_args,
|
||||||
|
add_tokens=self.add_tokens,
|
||||||
|
add_patches=self.add_patches,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessIterator(StatefulIterator):
|
||||||
|
"""
|
||||||
|
Take BltExamples with fields filled in only from ArrowFileIterator, and fill in fields that require
|
||||||
|
preprocessing like tokenization and patching
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
arrow_iterator: ArrowFileIterator,
|
||||||
|
*,
|
||||||
|
patcher_args: PatcherArgs,
|
||||||
|
tokenizer_args: TokenizerArgs,
|
||||||
|
add_tokens: bool = True,
|
||||||
|
add_patches: bool = True,
|
||||||
|
):
|
||||||
|
self.arrow_iterator = arrow_iterator
|
||||||
|
self.tokenizer_args = tokenizer_args
|
||||||
|
self.patcher_args = patcher_args
|
||||||
|
self.add_tokens = add_tokens
|
||||||
|
self.add_patches = add_patches
|
||||||
|
self.tokenizer: BltTokenizer | None = None
|
||||||
|
self.patcher: Patcher | None = None
|
||||||
|
|
||||||
|
def get_state(self) -> PreprocessIteratorState:
|
||||||
|
"""
|
||||||
|
The only state to maintain here is from arrow, there
|
||||||
|
isn't any internal state on this iterator.
|
||||||
|
"""
|
||||||
|
return PreprocessIteratorState(
|
||||||
|
arrow_file_iterator_state=self.arrow_iterator.get_state(),
|
||||||
|
tokenizer_args=self.tokenizer_args,
|
||||||
|
patcher_args=self.patcher_args,
|
||||||
|
add_tokens=self.add_tokens,
|
||||||
|
add_patches=self.add_patches,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_iter(self) -> Generator[BltExample, Any, None]:
|
||||||
|
if self.tokenizer is None and self.add_tokens:
|
||||||
|
self.tokenizer = self.tokenizer_args.build()
|
||||||
|
if self.patcher is None and self.add_patches:
|
||||||
|
self.patcher = self.patcher_args.build()
|
||||||
|
|
||||||
|
example_iter = self.arrow_iterator.create_iter()
|
||||||
|
for example in example_iter:
|
||||||
|
if self.add_tokens:
|
||||||
|
tokens = self.tokenizer.encode(example.text)
|
||||||
|
else:
|
||||||
|
tokens = example.tokens
|
||||||
|
if (
|
||||||
|
self.patcher is not None
|
||||||
|
and self.patcher.patching_mode == PatchingModeEnum.entropy
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
example.entropies is not None
|
||||||
|
), "For patching, entropies cannot be None"
|
||||||
|
entropies = torch.tensor(example.entropies).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
entropies = None
|
||||||
|
if self.patcher is None:
|
||||||
|
patch_lengths = None
|
||||||
|
else:
|
||||||
|
patch_lengths = self.patcher.patch(
|
||||||
|
torch.tensor(tokens).unsqueeze(0),
|
||||||
|
include_next_token=False,
|
||||||
|
entropies=entropies,
|
||||||
|
)[0][0].tolist()
|
||||||
|
yield BltExample(
|
||||||
|
sample_id=example.sample_id,
|
||||||
|
text=example.text,
|
||||||
|
tokens=tokens,
|
||||||
|
mask=[True] * len(tokens),
|
||||||
|
patch_lengths=patch_lengths,
|
||||||
|
entropies=example.entropies,
|
||||||
|
)
|
66
bytelatent/data/iterators/sampling_iterator.py
Normal file
66
bytelatent/data/iterators/sampling_iterator.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
|
||||||
|
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingIteratorState(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
rng_state: dict[str, Any]
|
||||||
|
source_to_weight: dict[str, float]
|
||||||
|
source_to_iterator_state: dict[str, SequenceIteratorState]
|
||||||
|
|
||||||
|
def build(self) -> "SamplingIterator":
|
||||||
|
return SamplingIterator(
|
||||||
|
rng_state=self.rng_state,
|
||||||
|
source_to_weight=self.source_to_weight,
|
||||||
|
source_to_iterator={
|
||||||
|
source: state.build()
|
||||||
|
for source, state in self.source_to_iterator_state.items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingIterator(StatefulIterator):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
rng_state: dict[str, Any],
|
||||||
|
source_to_weight: dict[str, float],
|
||||||
|
source_to_iterator: dict[str, StatefulIterator],
|
||||||
|
):
|
||||||
|
self.rng = np.random.default_rng()
|
||||||
|
self.rng.bit_generator.state = rng_state
|
||||||
|
self.source_to_weight = source_to_weight
|
||||||
|
self.source_to_iterator = source_to_iterator
|
||||||
|
|
||||||
|
def get_state(self) -> SamplingIteratorState:
|
||||||
|
return SamplingIteratorState(
|
||||||
|
rng_state=self.rng.bit_generator.state,
|
||||||
|
source_to_weight=self.source_to_weight,
|
||||||
|
source_to_iterator_state={
|
||||||
|
source: iterator.get_state()
|
||||||
|
for source, iterator in self.source_to_iterator.items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_iter(self):
|
||||||
|
n_sources = len(self.source_to_weight)
|
||||||
|
possible_sources = []
|
||||||
|
weights = []
|
||||||
|
for source, w in self.source_to_weight.items():
|
||||||
|
possible_sources.append(source)
|
||||||
|
weights.append(w)
|
||||||
|
|
||||||
|
source_to_python_iter = {
|
||||||
|
source: self.source_to_iterator[source].create_iter()
|
||||||
|
for source in possible_sources
|
||||||
|
}
|
||||||
|
while True:
|
||||||
|
norm_weights = np.array(weights) / np.array(weights).sum()
|
||||||
|
source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)]
|
||||||
|
yield next(source_to_python_iter[source_choice])
|
122
bytelatent/data/iterators/sequence_iterator.py
Normal file
122
bytelatent/data/iterators/sequence_iterator.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from bytelatent.data.data_types import BltSequence
|
||||||
|
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
||||||
|
from bytelatent.data.iterators.preprocess_iterator import (
|
||||||
|
PreprocessIterator,
|
||||||
|
PreprocessIteratorState,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class SequencePackingArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
output_seq_len: int
|
||||||
|
buffer_size: int
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceIteratorState(BaseModel, IteratorState):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
sequence_packing_args: SequencePackingArgs
|
||||||
|
preprocess_iterator_state: PreprocessIteratorState
|
||||||
|
rng_state: dict[str, Any]
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
preprocess_iterator = self.preprocess_iterator_state.build()
|
||||||
|
return SequenceIterator(
|
||||||
|
preprocess_iterator,
|
||||||
|
sequence_packing_args=self.sequence_packing_args,
|
||||||
|
rng_state=self.rng_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceIterator(StatefulIterator):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
preprocess_iterator: PreprocessIterator,
|
||||||
|
*,
|
||||||
|
rng_state: dict[str, Any],
|
||||||
|
sequence_packing_args: SequencePackingArgs,
|
||||||
|
):
|
||||||
|
self.preprocess_iterator = preprocess_iterator
|
||||||
|
self.sequence_packing_args = sequence_packing_args
|
||||||
|
self.output_seq_len = sequence_packing_args.output_seq_len
|
||||||
|
self.buffer_size = sequence_packing_args.buffer_size
|
||||||
|
self.rng = np.random.default_rng()
|
||||||
|
self.rng.bit_generator.state = rng_state
|
||||||
|
|
||||||
|
def get_state(self):
|
||||||
|
# TODO: need to also perist the current shuffle buffer
|
||||||
|
return SequenceIteratorState(
|
||||||
|
sequence_packing_args=self.sequence_packing_args,
|
||||||
|
preprocess_iterator_state=self.preprocess_iterator.get_state(),
|
||||||
|
rng_state=self.rng.bit_generator.state,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_iter(self):
|
||||||
|
example_iter = self.preprocess_iterator.create_iter()
|
||||||
|
n_buffer_patches = self.buffer_size * self.output_seq_len
|
||||||
|
|
||||||
|
patch_lengths: list[int] = []
|
||||||
|
tokens: list[int] = []
|
||||||
|
mask: list[bool] = []
|
||||||
|
first = True
|
||||||
|
for example in example_iter:
|
||||||
|
assert example.tokens is not None
|
||||||
|
assert example.mask is not None
|
||||||
|
assert example.patch_lengths is not None
|
||||||
|
assert len(example.tokens) != 0
|
||||||
|
assert len(example.mask) != 0
|
||||||
|
assert len(example.tokens) == len(example.mask)
|
||||||
|
assert len(example.tokens) == sum(example.patch_lengths)
|
||||||
|
|
||||||
|
tokens.extend(example.tokens)
|
||||||
|
mask.extend(example.mask)
|
||||||
|
patch_lengths.extend(example.patch_lengths)
|
||||||
|
|
||||||
|
while len(patch_lengths) >= n_buffer_patches:
|
||||||
|
if first:
|
||||||
|
first = False
|
||||||
|
logger.info("First buffer complete")
|
||||||
|
|
||||||
|
x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
|
||||||
|
self.buffer_size, self.output_seq_len
|
||||||
|
)
|
||||||
|
seq_tokens = []
|
||||||
|
seq_mask = []
|
||||||
|
start_id = 0
|
||||||
|
# We fix the number of patches and therefore global steps per batch
|
||||||
|
# so we have a variable number of tokens we need to account for
|
||||||
|
for num_tokens in x_patches.sum(axis=-1):
|
||||||
|
seq_tokens.append(tokens[start_id : start_id + num_tokens])
|
||||||
|
seq_mask.append(mask[start_id : start_id + num_tokens])
|
||||||
|
start_id += num_tokens
|
||||||
|
|
||||||
|
assert start_id == x_patches.sum()
|
||||||
|
|
||||||
|
# Remove what we just added from the buffer
|
||||||
|
patch_lengths = patch_lengths[n_buffer_patches:]
|
||||||
|
tokens = tokens[x_patches.sum() :]
|
||||||
|
mask = mask[x_patches.sum() :]
|
||||||
|
|
||||||
|
seq_patch_lengths: list[list[int]] = x_patches.tolist()
|
||||||
|
assert len(seq_patch_lengths) == self.buffer_size
|
||||||
|
for idx in self.rng.permutation(len(seq_patch_lengths)):
|
||||||
|
assert len(seq_patch_lengths[idx]) == self.output_seq_len
|
||||||
|
assert (
|
||||||
|
sum(seq_patch_lengths[idx])
|
||||||
|
== len(seq_tokens[idx])
|
||||||
|
== len(seq_mask[idx])
|
||||||
|
), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
|
||||||
|
assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
|
||||||
|
yield BltSequence(
|
||||||
|
tokens=seq_tokens[idx],
|
||||||
|
mask=seq_mask[idx],
|
||||||
|
patch_lengths=seq_patch_lengths[idx],
|
||||||
|
)
|
89
bytelatent/data/iterators/test_arrow_iterator.py
Normal file
89
bytelatent/data/iterators/test_arrow_iterator.py
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
# pyarrow needs the initialization from this import
|
||||||
|
import pyarrow.dataset # pyright: ignore
|
||||||
|
|
||||||
|
from bytelatent.constants import BLT_DATA
|
||||||
|
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
|
||||||
|
|
||||||
|
ENTROPY_MODEL = "transformer_100m"
|
||||||
|
ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
|
||||||
|
ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow")
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_arrow_file():
|
||||||
|
dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow")
|
||||||
|
n_head = 1000
|
||||||
|
head_df = dataset.head(n_head).to_pandas()
|
||||||
|
|
||||||
|
initial_state = ArrowFileIteratorState(
|
||||||
|
file_path=None,
|
||||||
|
num_workers=1,
|
||||||
|
worker_id=0,
|
||||||
|
preprocess_dir=None,
|
||||||
|
entropy_model_name=ENTROPY_MODEL,
|
||||||
|
dataset_files=[ARROW_TEST_DATA_1],
|
||||||
|
row_num=0,
|
||||||
|
arrow_batch_size=100,
|
||||||
|
)
|
||||||
|
arrow_file = initial_state.build()
|
||||||
|
start_state = arrow_file.get_state()
|
||||||
|
assert start_state.row_num == initial_state.row_num
|
||||||
|
|
||||||
|
sample_id = None
|
||||||
|
for example in arrow_file.create_iter():
|
||||||
|
sample_id = example.sample_id
|
||||||
|
assert head_df.iloc[0]["sample_id"] == sample_id
|
||||||
|
break
|
||||||
|
|
||||||
|
assert arrow_file.get_state().row_num == 1
|
||||||
|
arrow_file = initial_state.build()
|
||||||
|
for example in arrow_file.create_iter():
|
||||||
|
assert example.sample_id == sample_id
|
||||||
|
assert head_df.iloc[0]["sample_id"] == sample_id
|
||||||
|
break
|
||||||
|
|
||||||
|
# Test resume far enough in to be past the batch size of 100
|
||||||
|
resumed_state = ArrowFileIteratorState(
|
||||||
|
file_path=None,
|
||||||
|
num_workers=1,
|
||||||
|
worker_id=0,
|
||||||
|
preprocess_dir=None,
|
||||||
|
entropy_model_name=ENTROPY_MODEL,
|
||||||
|
dataset_files=[ARROW_TEST_DATA_1],
|
||||||
|
row_num=251,
|
||||||
|
arrow_batch_size=100,
|
||||||
|
)
|
||||||
|
arrow_file = resumed_state.build()
|
||||||
|
for example in arrow_file.create_iter():
|
||||||
|
assert example.sample_id == head_df.iloc[251]["sample_id"]
|
||||||
|
assert arrow_file.get_state().row_num == 252
|
||||||
|
break
|
||||||
|
|
||||||
|
world_rank = 1
|
||||||
|
world_size = 4
|
||||||
|
# Test World Size and Rank
|
||||||
|
rank_state = ArrowFileIteratorState(
|
||||||
|
file_path=None,
|
||||||
|
num_workers=world_size,
|
||||||
|
worker_id=world_rank,
|
||||||
|
preprocess_dir=None,
|
||||||
|
entropy_model_name=ENTROPY_MODEL,
|
||||||
|
dataset_files=[ARROW_TEST_DATA_1],
|
||||||
|
row_num=0,
|
||||||
|
arrow_batch_size=100,
|
||||||
|
)
|
||||||
|
arrow_file = rank_state.build()
|
||||||
|
expected_ids = []
|
||||||
|
for i in range(n_head):
|
||||||
|
if i % world_size == world_rank:
|
||||||
|
expected_ids.append(head_df.iloc[i]["sample_id"])
|
||||||
|
print(len(expected_ids))
|
||||||
|
i = 0
|
||||||
|
for example in arrow_file.create_iter():
|
||||||
|
assert example.sample_id == expected_ids[i]
|
||||||
|
i += 1
|
||||||
|
if i >= len(expected_ids):
|
||||||
|
break
|
162
bytelatent/data/iterators/test_iters.py
Normal file
162
bytelatent/data/iterators/test_iters.py
Normal file
|
@ -0,0 +1,162 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import pandas as pd
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from bytelatent.constants import BLT_DATA
|
||||||
|
from bytelatent.data.data_types import BltExample
|
||||||
|
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
||||||
|
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
|
||||||
|
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
|
||||||
|
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
||||||
|
|
||||||
|
|
||||||
|
class BltTestIteratorState(BaseModel, IteratorState):
|
||||||
|
position: int
|
||||||
|
total: int
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
blt_iter = BltTestIteratorState(total=self.total)
|
||||||
|
blt_iter.position = self.position
|
||||||
|
return blt_iter
|
||||||
|
|
||||||
|
|
||||||
|
class BltTestIterator(StatefulIterator):
|
||||||
|
def __init__(self, total: int):
|
||||||
|
self.position = 0
|
||||||
|
self.total = total
|
||||||
|
|
||||||
|
def get_state(self):
|
||||||
|
return BltTestIteratorState(position=self.position, total=self.total)
|
||||||
|
|
||||||
|
def create_iter(self):
|
||||||
|
for i in range(self.total):
|
||||||
|
self.position += 1
|
||||||
|
yield BltExample(
|
||||||
|
sample_id=f"test_{i}",
|
||||||
|
text=f"This is some test {i} text.",
|
||||||
|
tokens=None,
|
||||||
|
mask=None,
|
||||||
|
entropies=None,
|
||||||
|
patch_lengths=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
|
||||||
|
position: int
|
||||||
|
total: int
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
|
||||||
|
blt_iter.position = self.position
|
||||||
|
return blt_iter
|
||||||
|
|
||||||
|
|
||||||
|
class BltTestWithEntropiesIterator(StatefulIterator):
|
||||||
|
def __init__(self, total: int):
|
||||||
|
self.position = 0
|
||||||
|
self.total = total
|
||||||
|
|
||||||
|
def get_state(self):
|
||||||
|
return BltTestIteratorState(position=self.position, total=self.total)
|
||||||
|
|
||||||
|
def create_iter(self):
|
||||||
|
text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
|
||||||
|
df = pd.read_json("fixtures/tokens_with_entropies.json")
|
||||||
|
tokens = df["token_ids"].tolist()
|
||||||
|
entropies = df["entropies"].tolist()
|
||||||
|
# BOS and EOS
|
||||||
|
assert len(tokens) == len(text) + 2
|
||||||
|
for i in range(self.total):
|
||||||
|
self.position += 1
|
||||||
|
yield BltExample(
|
||||||
|
sample_id=f"test_{i}",
|
||||||
|
text=text,
|
||||||
|
tokens=tokens,
|
||||||
|
mask=[True] * len(tokens),
|
||||||
|
entropies=entropies,
|
||||||
|
patch_lengths=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocess_iter():
|
||||||
|
total = 3
|
||||||
|
tokenizer_args = TokenizerArgs(
|
||||||
|
name="blt",
|
||||||
|
init_kwargs={
|
||||||
|
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for mode in [
|
||||||
|
PatchingModeEnum.bpe,
|
||||||
|
PatchingModeEnum.space,
|
||||||
|
]:
|
||||||
|
data_it = BltTestIterator(total)
|
||||||
|
patcher_args = PatcherArgs(patching_mode=mode)
|
||||||
|
example_it = PreprocessIterator(
|
||||||
|
data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
|
||||||
|
)
|
||||||
|
count = 0
|
||||||
|
for example in example_it.create_iter():
|
||||||
|
assert isinstance(example.tokens, list)
|
||||||
|
assert isinstance(example.tokens[0], int)
|
||||||
|
# BOS and EOS
|
||||||
|
assert len(example.tokens) == len(example.text) + 2
|
||||||
|
assert example.mask is not None
|
||||||
|
assert len(example.tokens) == len(example.mask)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
assert count == total
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_entropy_patch_iter():
|
||||||
|
total = 3
|
||||||
|
tokenizer_args = TokenizerArgs(
|
||||||
|
name="blt",
|
||||||
|
init_kwargs={
|
||||||
|
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for mode in [
|
||||||
|
PatchingModeEnum.bpe,
|
||||||
|
PatchingModeEnum.space,
|
||||||
|
]:
|
||||||
|
patcher_args = PatcherArgs(patching_mode=mode)
|
||||||
|
data_it = BltTestIterator(total)
|
||||||
|
example_it = PreprocessIterator(
|
||||||
|
data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
|
||||||
|
)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for example in example_it.create_iter():
|
||||||
|
assert isinstance(example.patch_lengths, list)
|
||||||
|
assert isinstance(example.patch_lengths[0], int)
|
||||||
|
assert len(example.tokens) == sum(example.patch_lengths)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
assert count == total
|
||||||
|
|
||||||
|
|
||||||
|
def test_entropy_patch_iter():
|
||||||
|
total = 2
|
||||||
|
patcher_args = PatcherArgs(
|
||||||
|
patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627
|
||||||
|
)
|
||||||
|
tokenizer_args = TokenizerArgs(
|
||||||
|
name="blt",
|
||||||
|
init_kwargs={
|
||||||
|
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
data_it = BltTestWithEntropiesIterator(total)
|
||||||
|
example_it = PreprocessIterator(
|
||||||
|
data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
|
||||||
|
)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for example in example_it.create_iter():
|
||||||
|
assert isinstance(example.patch_lengths, list)
|
||||||
|
assert isinstance(example.patch_lengths[0], int)
|
||||||
|
assert len(example.tokens) == sum(example.patch_lengths)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
assert count == total
|
146
bytelatent/data/ngram_processor.py
Normal file
146
bytelatent/data/ngram_processor.py
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from bytelatent import ByteLatentError
|
||||||
|
|
||||||
|
LOOKUP_OFFSET = 4
|
||||||
|
|
||||||
|
|
||||||
|
def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1):
|
||||||
|
"""
|
||||||
|
Wrapper function for applying the lookup table to each n-gram.
|
||||||
|
|
||||||
|
:param ngram: Array of numbers representing an n-gram.
|
||||||
|
:param lookup_table: Dictionary where keys are tuples (n-grams) and values are the desired outputs.
|
||||||
|
:param lookup_offset: Offset to add to the lookup result.
|
||||||
|
:return: The value associated with the n-gram tuple in the dictionary, or None if not found.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def apply_lookup_table(ngram):
|
||||||
|
"""
|
||||||
|
Function to apply to each n-gram: converts it to a tuple and looks it up in a dictionary.
|
||||||
|
|
||||||
|
:param ngram: Array of numbers representing an n-gram.
|
||||||
|
:return: The value associated with the n-gram tuple in the dictionary, or None if not found.
|
||||||
|
"""
|
||||||
|
# Convert the n-gram to a tuple
|
||||||
|
ngram_tuple = tuple(ngram)
|
||||||
|
|
||||||
|
if ngram_tuple not in ngram_to_idx:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return ngram_to_idx[ngram_tuple] + lookup_offset
|
||||||
|
|
||||||
|
return apply_lookup_table
|
||||||
|
|
||||||
|
|
||||||
|
def get_byte_ngrams_ids(
|
||||||
|
byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate n-grams from a 2D numpy array.
|
||||||
|
|
||||||
|
:param n: The length of each n-gram.
|
||||||
|
:param pad_value: The value used for padding of the byte values to maintain the same dimensions for the n-grams.
|
||||||
|
:return: A 2D numpy array where each element is the ID of an n-gram offset by LOOKUP_OFFSET.
|
||||||
|
"""
|
||||||
|
num_rows, num_cols = byte_array.shape
|
||||||
|
|
||||||
|
# Create an array to hold the padded version of the original array
|
||||||
|
padded_array = np.pad(
|
||||||
|
byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use stride tricks to avoid explicit looping
|
||||||
|
strided = np.lib.stride_tricks.as_strided
|
||||||
|
shape = (num_rows, num_cols, n)
|
||||||
|
strides = padded_array.strides[:2] + (padded_array.strides[1],)
|
||||||
|
ngrams = strided(padded_array, shape=shape, strides=strides)
|
||||||
|
|
||||||
|
ngram_ids = np.apply_along_axis(
|
||||||
|
apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET), 2, ngrams
|
||||||
|
)
|
||||||
|
assert ngram_ids.shape == byte_array.shape
|
||||||
|
return ngram_ids
|
||||||
|
|
||||||
|
|
||||||
|
def reload_tables(
|
||||||
|
ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET
|
||||||
|
) -> tuple[dict[int, list], dict[tuple, int], dict[int, int]]:
|
||||||
|
"""
|
||||||
|
Reload lookup tables from a directory. Reload only the ngrams in the dictionary and per ngram,
|
||||||
|
only load up to the max specified size. Return the actual number of ngrams taken per ngram size.
|
||||||
|
"""
|
||||||
|
idx_to_ngram_tables = {}
|
||||||
|
ngram_to_idx_tables = {}
|
||||||
|
vocab_sizes = {}
|
||||||
|
for ngram, size in ngram_to_size.items():
|
||||||
|
with open(Path(ngram_table_dir) / f"ngram-{ngram}.pickle", "rb") as f:
|
||||||
|
# These are already sorted by count
|
||||||
|
# Value: tuple of: count, ngram, dataset
|
||||||
|
ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[
|
||||||
|
"counts"
|
||||||
|
]
|
||||||
|
table = [ngram for ngram, _ in ngram_data][:size]
|
||||||
|
if len(table) != size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Ngram table for {ngram}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}"
|
||||||
|
)
|
||||||
|
ngram_to_idx = {ngram: idx for idx, ngram in enumerate(table)}
|
||||||
|
actual_size = len(table)
|
||||||
|
idx_to_ngram_tables[ngram] = table
|
||||||
|
ngram_to_idx_tables[ngram] = ngram_to_idx
|
||||||
|
vocab_sizes[ngram] = actual_size + offset
|
||||||
|
return ngram_to_idx_tables, ngram_to_idx_tables, vocab_sizes
|
||||||
|
|
||||||
|
|
||||||
|
def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]:
|
||||||
|
if ngram_to_size_str is None:
|
||||||
|
return None
|
||||||
|
ngram_to_size = {}
|
||||||
|
for entry in ngram_to_size_str.split(","):
|
||||||
|
ngram, size = entry.split(":")
|
||||||
|
ngram = int(ngram)
|
||||||
|
size = int(size)
|
||||||
|
ngram_to_size[ngram] = size
|
||||||
|
return ngram_to_size
|
||||||
|
|
||||||
|
|
||||||
|
class NgramProcessor:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ngram_table_dir: str | None = None,
|
||||||
|
ngram_to_size: dict[int, int] | None = None,
|
||||||
|
):
|
||||||
|
if ngram_table_dir is None or ngram_to_size is None:
|
||||||
|
raise ByteLatentError(
|
||||||
|
"ngram_table_dir and ngram_to_size cannot be none if enable_byte_ngrams is True"
|
||||||
|
)
|
||||||
|
(
|
||||||
|
self.ngram_to_idx_tables,
|
||||||
|
self.idx_to_ngram_tables,
|
||||||
|
self.ngram_vocab_sizes,
|
||||||
|
) = reload_tables(ngram_table_dir, ngram_to_size)
|
||||||
|
# Lowest to highest ngram
|
||||||
|
self.ngram_sizes = sorted(list(self.ngram_to_idx_tables.keys()))
|
||||||
|
# Although the model might not use all the ngrams, we need the tokenizer
|
||||||
|
# to produce ngram_ids such that index zero is the 2-gram, later on in
|
||||||
|
# src.model.megabyte.Megabyte.forward
|
||||||
|
assert self.ngram_sizes[0] == 2
|
||||||
|
|
||||||
|
def encode_single_ngram_table(self, data: np.ndarray, n: int):
|
||||||
|
"""
|
||||||
|
Return the n-grams of the input data for a given n
|
||||||
|
numpy array with ids of shape data.shape
|
||||||
|
"""
|
||||||
|
return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0)
|
||||||
|
|
||||||
|
def encode_token_ngrams(self, data: np.ndarray):
|
||||||
|
"""
|
||||||
|
Return the n-grams of the input data.
|
||||||
|
output shape: [ids with data.shape for n in self.ngram_sizes]
|
||||||
|
"""
|
||||||
|
return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes]
|
609
bytelatent/data/patcher.py
Normal file
609
bytelatent/data/patcher.py
Normal file
|
@ -0,0 +1,609 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from bytelatent.distributed import get_local_rank
|
||||||
|
from bytelatent.entropy_model import load_entropy_model
|
||||||
|
|
||||||
|
# from src.slurm import get_local_rank
|
||||||
|
from bytelatent.tokenizers.blt_tokenizer import BPE_ID, OFFSET
|
||||||
|
from bytelatent.tokenizers.constants import BPE_ID, OFFSET
|
||||||
|
|
||||||
|
|
||||||
|
class PatchingModeEnum(str, Enum):
|
||||||
|
entropy = "entropy"
|
||||||
|
bpe = "bpe"
|
||||||
|
bpe_patcher = "bpe_patcher"
|
||||||
|
space = "space"
|
||||||
|
|
||||||
|
|
||||||
|
class PatcherArgs(BaseModel):
|
||||||
|
patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
|
||||||
|
patching_device: str = "cuda"
|
||||||
|
entropy_model_checkpoint_dir: str | None = None
|
||||||
|
realtime_patching: bool = False
|
||||||
|
threshold: float = 1.335442066192627
|
||||||
|
threshold_add: float | None = None
|
||||||
|
max_patch_length: int | None = None
|
||||||
|
patch_size: float = 4.5
|
||||||
|
patching_batch_size: int = 1
|
||||||
|
data_loader_patching: bool = False
|
||||||
|
device: str = "cuda"
|
||||||
|
monotonicity: bool = False
|
||||||
|
log_time: bool = False
|
||||||
|
|
||||||
|
def build(self) -> "Patcher":
|
||||||
|
return Patcher(self)
|
||||||
|
|
||||||
|
|
||||||
|
def entropy(scores):
|
||||||
|
"""
|
||||||
|
scores: [bs, seq_len, vocab]
|
||||||
|
returns [bs, seq_len]
|
||||||
|
|
||||||
|
Computes the entropy for each token in the batch.
|
||||||
|
Note: uses natural log.
|
||||||
|
"""
|
||||||
|
log_probs = F.log_softmax(scores, dim=-1)
|
||||||
|
probs = torch.exp(log_probs)
|
||||||
|
p_log_p = log_probs * probs
|
||||||
|
entropy = -p_log_p.sum(dim=-1)
|
||||||
|
return entropy
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_entropies(
|
||||||
|
tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
tokens: 2D tensor of shape [batch_size, seq_len]
|
||||||
|
Return 2D tensor of shape [batch_size, seq_len] with entropies for each token.
|
||||||
|
|
||||||
|
Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
|
||||||
|
Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
entropies = []
|
||||||
|
max_length = getattr(entropy_model, "max_length", 8192)
|
||||||
|
batch_numel = max_length * patching_batch_size
|
||||||
|
splits = torch.split(tokens.flatten(), batch_numel)
|
||||||
|
for split in splits:
|
||||||
|
pad_size = (max_length - (split.numel() % max_length)) % max_length
|
||||||
|
pad = torch.zeros(
|
||||||
|
pad_size, dtype=split.dtype, device=split.device, requires_grad=False
|
||||||
|
)
|
||||||
|
split = torch.cat((split, pad), dim=0)
|
||||||
|
split = split.reshape(-1, max_length)
|
||||||
|
if device is not None:
|
||||||
|
split = split.to(device)
|
||||||
|
assert torch.all(split >= 0) and torch.all(split < 260)
|
||||||
|
pred, _ = entropy_model(split)
|
||||||
|
pred = pred.reshape(-1, pred.shape[-1])[
|
||||||
|
: split.numel() - pad_size, :
|
||||||
|
] # [batch_size * seq_len, vocab]
|
||||||
|
pred_entropies = entropy(pred)
|
||||||
|
entropies.append(pred_entropies)
|
||||||
|
|
||||||
|
entropies = torch.cat(entropies, dim=0)
|
||||||
|
entropies = entropies.reshape(tokens.shape)
|
||||||
|
return entropies
|
||||||
|
|
||||||
|
|
||||||
|
def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
|
||||||
|
"""
|
||||||
|
entropies: [bs, seq_len] torch tensor of entropies
|
||||||
|
t: threshold
|
||||||
|
returns [bs, seq_len] mask where True indicates the start of a patch
|
||||||
|
"""
|
||||||
|
bs, seq_len = entropies.shape
|
||||||
|
mask = torch.zeros_like(entropies, dtype=torch.bool)
|
||||||
|
mask[:, 0] = True
|
||||||
|
|
||||||
|
# Calculate differences between consecutive elements along the sequence length
|
||||||
|
differences = entropies[:, 1:] - entropies[:, :-1]
|
||||||
|
|
||||||
|
# Calculate conditions for all elements except the first one in each sequence
|
||||||
|
condition = differences > t
|
||||||
|
|
||||||
|
# Update the mask based on the condition
|
||||||
|
mask[:, 1:] = condition
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
|
||||||
|
"""
|
||||||
|
entropies: [bs, seq_len] torch tensor of entropies
|
||||||
|
t: threshold
|
||||||
|
returns [bs, seq_len] mask where True indicates the start of a patch
|
||||||
|
"""
|
||||||
|
bs, seq_len = entropies.shape
|
||||||
|
mask = torch.zeros_like(entropies, dtype=torch.bool)
|
||||||
|
mask[:, 0] = True
|
||||||
|
|
||||||
|
# Calculate differences between consecutive elements along the sequence length
|
||||||
|
differences = entropies[:, 1:] - entropies[:, :-1]
|
||||||
|
|
||||||
|
# Calculate conditions for all elements except the first one in each sequence
|
||||||
|
condition = (differences > t_add) & (entropies[:, 1:] > t) & (~mask[:, :-1])
|
||||||
|
|
||||||
|
# Update the mask based on the condition
|
||||||
|
mask[:, 1:] = condition
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def patch_start_ids_from_patch_start_mask(patch_start_mask):
|
||||||
|
bs, trunc_seq_len = patch_start_mask.shape
|
||||||
|
max_patches = patch_start_mask.sum(dim=1).max()
|
||||||
|
if max_patches == 0:
|
||||||
|
patch_start_ids = torch.full(
|
||||||
|
(bs, trunc_seq_len),
|
||||||
|
trunc_seq_len,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=patch_start_mask.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
patch_ids = (
|
||||||
|
torch.arange(trunc_seq_len, device=patch_start_mask.device)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.repeat(bs, 1)
|
||||||
|
)
|
||||||
|
extra_patch_ids = torch.full(
|
||||||
|
(bs, trunc_seq_len),
|
||||||
|
trunc_seq_len,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=patch_start_mask.device,
|
||||||
|
)
|
||||||
|
all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
|
||||||
|
patch_start_mask_padded = torch.cat(
|
||||||
|
(patch_start_mask, ~patch_start_mask), dim=1
|
||||||
|
)
|
||||||
|
patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(
|
||||||
|
bs, trunc_seq_len
|
||||||
|
)[:, :max_patches]
|
||||||
|
return patch_start_ids
|
||||||
|
|
||||||
|
|
||||||
|
def check_non_zero_after_zero(tensor):
|
||||||
|
zero_mask = tensor == 0
|
||||||
|
shifted_mask = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
|
||||||
|
zero_mask[:, :-1],
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
non_zero_after_zero = (tensor != 0) & shifted_mask
|
||||||
|
return non_zero_after_zero.any()
|
||||||
|
|
||||||
|
|
||||||
|
def patch_lengths_from_start_ids(patch_start_ids, seq_len):
|
||||||
|
"""
|
||||||
|
Calculate patch lengths from start ids.
|
||||||
|
start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then
|
||||||
|
the rest are filled to the seq len.
|
||||||
|
seq_len: ex: 7 length of the sequence
|
||||||
|
|
||||||
|
returns the patch lengths:
|
||||||
|
[1, 6] for the above example.
|
||||||
|
"""
|
||||||
|
last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
|
||||||
|
patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
|
||||||
|
patch_lengths = patch_end_ids - patch_start_ids + 1
|
||||||
|
assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
|
||||||
|
assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
|
||||||
|
return patch_lengths
|
||||||
|
|
||||||
|
|
||||||
|
def find_space_patch_start_ids(tokens):
|
||||||
|
bs, seq_len = tokens.shape
|
||||||
|
tokens_no_offset = tokens - OFFSET
|
||||||
|
patch_end_mask = (
|
||||||
|
(tokens_no_offset < ord("0"))
|
||||||
|
| ((ord("9") < tokens_no_offset) & (tokens_no_offset < ord("A")))
|
||||||
|
| ((ord("Z") < tokens_no_offset) & (tokens_no_offset < ord("a")))
|
||||||
|
| ((ord("z") < tokens_no_offset) & (tokens_no_offset < 0b1000_0000))
|
||||||
|
| (0b1100_0000 <= tokens_no_offset)
|
||||||
|
)
|
||||||
|
patch_end_mask[:, 1:] &= patch_end_mask[:, :-1].bitwise_not()
|
||||||
|
patch_end_mask |= tokens < OFFSET
|
||||||
|
|
||||||
|
patch_start_mask = torch.cat(
|
||||||
|
[
|
||||||
|
torch.tensor([1, 1], device=tokens.device, dtype=torch.bool)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.repeat(bs, 1),
|
||||||
|
patch_end_mask[:, 1:],
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
max_patches = patch_start_mask.sum(dim=1).max()
|
||||||
|
|
||||||
|
patch_ids = (
|
||||||
|
torch.arange(seq_len + 1, device=tokens.device).unsqueeze(0).repeat(bs, 1)
|
||||||
|
)
|
||||||
|
extra_patch_ids = torch.full(
|
||||||
|
(bs, seq_len + 1), seq_len + 1, dtype=torch.long, device=tokens.device
|
||||||
|
)
|
||||||
|
all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
|
||||||
|
patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
|
||||||
|
|
||||||
|
patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, -1)[
|
||||||
|
:, :max_patches
|
||||||
|
]
|
||||||
|
return patch_start_ids
|
||||||
|
|
||||||
|
|
||||||
|
def to_device(entropy_model, device=None):
|
||||||
|
if device == "cuda":
|
||||||
|
rank = get_local_rank()
|
||||||
|
device = f"cuda:{rank}"
|
||||||
|
entropy_model = entropy_model.to(device)
|
||||||
|
return entropy_model, device
|
||||||
|
|
||||||
|
|
||||||
|
def model_pred_to_bpe_patching_pred(pred):
|
||||||
|
_, indices = torch.max(pred, dim=1)
|
||||||
|
return indices == BPE_ID
|
||||||
|
|
||||||
|
|
||||||
|
def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None):
|
||||||
|
assert tokens.device == torch.device(
|
||||||
|
"cpu"
|
||||||
|
), f"{tokens.device} != cpu expects tokens to be on cpu"
|
||||||
|
with torch.no_grad():
|
||||||
|
bpe_patcher_device, device = to_device(
|
||||||
|
bpe_patcher, device
|
||||||
|
) # Get entropy model to right rank device.
|
||||||
|
bpe_patching_mask = []
|
||||||
|
max_length = getattr(bpe_patcher, "max_length", 8192)
|
||||||
|
batch_numel = max_length * patching_batch_size
|
||||||
|
splits = torch.split(tokens.flatten(), batch_numel)
|
||||||
|
for split in splits:
|
||||||
|
pad_size = (max_length - (split.numel() % max_length)) % max_length
|
||||||
|
pad = torch.zeros(
|
||||||
|
pad_size, dtype=split.dtype, device=split.device, requires_grad=False
|
||||||
|
)
|
||||||
|
split = torch.cat((split, pad), dim=0)
|
||||||
|
split = split.reshape(-1, max_length).to(device)
|
||||||
|
assert torch.all(split >= 0) and torch.all(split < 260)
|
||||||
|
pred = bpe_patcher_device(split)
|
||||||
|
pred_cpu = pred[0].cpu()
|
||||||
|
pred_cpu = pred_cpu.reshape(-1, pred_cpu.shape[-1])[
|
||||||
|
: split.numel() - pad_size, :
|
||||||
|
] # [batch_size * seq_len, vocab]
|
||||||
|
bpe_patching_pred = model_pred_to_bpe_patching_pred(pred_cpu)
|
||||||
|
bpe_patching_mask.append(bpe_patching_pred)
|
||||||
|
bpe_patching_mask = torch.cat(bpe_patching_mask, dim=0)
|
||||||
|
bpe_patching_mask = bpe_patching_mask.reshape(tokens.shape)
|
||||||
|
return bpe_patching_mask
|
||||||
|
|
||||||
|
|
||||||
|
def find_bpe_patcher_patch_start_ids(
|
||||||
|
tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True
|
||||||
|
):
|
||||||
|
bs, seq_len = tokens.shape
|
||||||
|
|
||||||
|
first_ids = (
|
||||||
|
torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.repeat(bs, 1)
|
||||||
|
)
|
||||||
|
preds_truncation_len = first_ids.shape[1]
|
||||||
|
token_input = tokens[:, 1:] if include_next_token else tokens[:, 1:-1]
|
||||||
|
if token_input.shape[1] >= 1:
|
||||||
|
patch_start_mask = apply_bpe_patcher(
|
||||||
|
token_input, bpe_patcher, patching_batch_size, device
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
patch_start_mask.shape[1]
|
||||||
|
== tokens.shape[1] + include_next_token - preds_truncation_len
|
||||||
|
), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
|
||||||
|
patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
|
||||||
|
patch_start_ids = torch.cat(
|
||||||
|
(first_ids, patch_start_ids + preds_truncation_len), dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
patch_start_ids = first_ids
|
||||||
|
return patch_start_ids
|
||||||
|
|
||||||
|
|
||||||
|
def find_entropy_patch_start_ids(
|
||||||
|
entropies,
|
||||||
|
patch_size=None,
|
||||||
|
threshold=None,
|
||||||
|
threshold_add=None,
|
||||||
|
monotonicity=False,
|
||||||
|
include_next_token=True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Use entropies to find the start ids of each patch.
|
||||||
|
Use patch_size or threshold to figure out the total number of patches to allocate.
|
||||||
|
|
||||||
|
When threshold is not None the number of patches is not constant between
|
||||||
|
different sequences, but patches can be identified incrementally rather than
|
||||||
|
decided globally using the entire sequence.
|
||||||
|
"""
|
||||||
|
bs, seq_len = entropies.shape[:2]
|
||||||
|
|
||||||
|
first_ids = (
|
||||||
|
torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.repeat(bs, 1)
|
||||||
|
)
|
||||||
|
preds_truncation_len = first_ids.shape[
|
||||||
|
1
|
||||||
|
] # remove the first preds because they will be start of patches.
|
||||||
|
entropies = entropies[:, 1:]
|
||||||
|
if threshold is None:
|
||||||
|
num_patches = seq_len // patch_size
|
||||||
|
patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
|
||||||
|
patch_start_ids = patch_start_ids.sort(dim=1).values
|
||||||
|
else:
|
||||||
|
# Assumes that there is at least one token going over the threshold
|
||||||
|
if monotonicity:
|
||||||
|
patch_start_mask = patch_start_mask_from_entropy_with_monotonicity(
|
||||||
|
entropies, threshold
|
||||||
|
)
|
||||||
|
elif threshold_add is not None and threshold is not None:
|
||||||
|
patch_start_mask = patch_start_mask_global_and_monotonicity(
|
||||||
|
entropies, threshold, threshold_add
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
patch_start_mask = entropies > threshold
|
||||||
|
if not include_next_token:
|
||||||
|
patch_start_mask = patch_start_mask[:, :-1]
|
||||||
|
# patch_start_mask[1:] |= tokens[:-1] < OFFSET
|
||||||
|
patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
|
||||||
|
|
||||||
|
patch_start_ids = torch.cat(
|
||||||
|
(first_ids, patch_start_ids + preds_truncation_len), dim=1
|
||||||
|
)
|
||||||
|
return patch_start_ids
|
||||||
|
|
||||||
|
|
||||||
|
def rightpad(seq, pad_id, max_len):
|
||||||
|
return seq + [pad_id] * (max_len - len(seq))
|
||||||
|
|
||||||
|
|
||||||
|
def find_bpe_delim_patch_start_ids(tokens, delim):
|
||||||
|
ids = (tokens[:, :-1] == delim).nonzero(as_tuple=False)
|
||||||
|
out = [[0, 1] for _ in range(tokens.shape[0])]
|
||||||
|
for x, y in ids:
|
||||||
|
# start is at delim + 1, delim should be the last element in the patch.
|
||||||
|
out[x.item()].append(y.item() + 1)
|
||||||
|
max_len = max([len(elt) for elt in out])
|
||||||
|
out = [rightpad(elt, tokens.shape[1], max_len) for elt in out]
|
||||||
|
patch_start_ids = torch.tensor(out, dtype=tokens.dtype, device=tokens.device)
|
||||||
|
return patch_start_ids
|
||||||
|
|
||||||
|
|
||||||
|
def find_lookup_table_start_mask(
|
||||||
|
tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
|
||||||
|
):
|
||||||
|
window_size = lookup_table.ndim
|
||||||
|
# Unfold the tensor to get sliding windows
|
||||||
|
unfolded = tokens.unfold(1, window_size, 1)
|
||||||
|
# Gather indices for each dimension
|
||||||
|
indices = [unfolded[..., i] for i in range(window_size)]
|
||||||
|
# Access the lookup table using the gathered indices
|
||||||
|
result = lookup_table[indices]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def find_lookup_table_patch_start_ids(
|
||||||
|
tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
|
||||||
|
):
|
||||||
|
bs, seq_len = tokens.shape
|
||||||
|
|
||||||
|
first_ids = (
|
||||||
|
torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.repeat(bs, 1)
|
||||||
|
)
|
||||||
|
preds_truncation_len = first_ids.shape[1]
|
||||||
|
window_size = lookup_table.ndim
|
||||||
|
assert window_size == 2, f"{window_size} != 2"
|
||||||
|
# output dimensions: token_input shape - window_size + 1 --> we want first ids + this = tokens shape + 1 if next token otherwise just token shape
|
||||||
|
token_input = (
|
||||||
|
tokens if include_next_token else tokens[:, : -preds_truncation_len + 1]
|
||||||
|
)
|
||||||
|
if token_input.shape[1] >= window_size:
|
||||||
|
patch_start_mask = find_lookup_table_start_mask(
|
||||||
|
token_input, lookup_table, include_next_token
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
patch_start_mask.shape[1]
|
||||||
|
== tokens.shape[1] + include_next_token - preds_truncation_len
|
||||||
|
), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
|
||||||
|
patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
|
||||||
|
patch_start_ids = torch.cat(
|
||||||
|
(first_ids, patch_start_ids + preds_truncation_len), dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
patch_start_ids = first_ids
|
||||||
|
return patch_start_ids
|
||||||
|
|
||||||
|
|
||||||
|
def split_large_numbers(lst, m):
|
||||||
|
new_lst = []
|
||||||
|
for i in lst:
|
||||||
|
if i > m:
|
||||||
|
while i > m:
|
||||||
|
new_lst.append(m)
|
||||||
|
i -= m
|
||||||
|
new_lst.append(i)
|
||||||
|
else:
|
||||||
|
new_lst.append(i)
|
||||||
|
assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
|
||||||
|
return new_lst
|
||||||
|
|
||||||
|
|
||||||
|
class Patcher:
|
||||||
|
def __init__(self, patcher_args: PatcherArgs):
|
||||||
|
self.patcher_args = patcher_args
|
||||||
|
self.patching_mode = patcher_args.patching_mode
|
||||||
|
self.realtime_patching = patcher_args.realtime_patching
|
||||||
|
if self.realtime_patching:
|
||||||
|
assert (
|
||||||
|
patcher_args.entropy_model_checkpoint_dir is not None
|
||||||
|
), "Cannot require realtime patching without an entropy model checkpoint"
|
||||||
|
entropy_model = load_entropy_model(
|
||||||
|
patcher_args.entropy_model_checkpoint_dir
|
||||||
|
)
|
||||||
|
entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
|
||||||
|
self.entropy_model = entropy_model
|
||||||
|
else:
|
||||||
|
self.entropy_model = None
|
||||||
|
self.threshold = patcher_args.threshold
|
||||||
|
self.threshold_add = patcher_args.threshold_add
|
||||||
|
self.max_patch_length = patcher_args.max_patch_length
|
||||||
|
self.patch_size = patcher_args.patch_size
|
||||||
|
self.patching_batch_size = patcher_args.patching_batch_size
|
||||||
|
self.data_loader_patching = patcher_args.data_loader_patching
|
||||||
|
self.device = patcher_args.device
|
||||||
|
self.monotonicity = patcher_args.monotonicity
|
||||||
|
self.log_time = patcher_args.log_time
|
||||||
|
if self.log_time:
|
||||||
|
self.log = defaultdict(float)
|
||||||
|
|
||||||
|
def patch(
|
||||||
|
self,
|
||||||
|
tokens: torch.Tensor,
|
||||||
|
include_next_token: bool = False,
|
||||||
|
preds: torch.Tensor | None = None,
|
||||||
|
entropies: torch.Tensor | None = None,
|
||||||
|
threshold: float = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
|
||||||
|
Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.)
|
||||||
|
-> output tensor: [batch_size, max_num_patches]
|
||||||
|
each tensor is processed independently and gets right padded with zeros.
|
||||||
|
|
||||||
|
Patching with the following modes:
|
||||||
|
1. patching_mode = None: static patch size
|
||||||
|
2. patching_mode = "entropy":
|
||||||
|
calculate entropy of each token, allocate patches so that the total
|
||||||
|
number of patches is the same as static patching but choose to begin
|
||||||
|
patches on tokens where the model is most uncertain (highest entropy).
|
||||||
|
|
||||||
|
When threshold is provided, it uses the threshold to decide when to
|
||||||
|
start a new patch.
|
||||||
|
3. patching_mode = "space":
|
||||||
|
use space like tokens to define the patches.
|
||||||
|
4. patching_mode = "bpe":
|
||||||
|
use bpe delim tokens to define the patches.
|
||||||
|
|
||||||
|
To correctly patch the last token, it may be necessary to include the next token in the patch
|
||||||
|
lengths calculations. This is controlled by the include_next_token argument.
|
||||||
|
"""
|
||||||
|
bs, seq_len = tokens.shape
|
||||||
|
seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
|
||||||
|
scores = None
|
||||||
|
# STATIC
|
||||||
|
if self.patching_mode is None:
|
||||||
|
patch_lengths = torch.zeros(
|
||||||
|
(bs, math.ceil(seq_len_next_tok / self.patch_size)),
|
||||||
|
dtype=tokens.dtype,
|
||||||
|
device=tokens.device,
|
||||||
|
).fill_(self.patch_size)
|
||||||
|
if seq_len_next_tok % self.patch_size != 0:
|
||||||
|
patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
|
||||||
|
# ENTROPY
|
||||||
|
elif self.patching_mode == PatchingModeEnum.entropy:
|
||||||
|
if self.log_time:
|
||||||
|
s = time.time()
|
||||||
|
if entropies is not None:
|
||||||
|
scores = torch.tensor(entropies, dtype=torch.float32)
|
||||||
|
elif preds is not None:
|
||||||
|
scores = entropy(preds)
|
||||||
|
else:
|
||||||
|
start_entropies = time.time()
|
||||||
|
scores = calculate_entropies(
|
||||||
|
tokens,
|
||||||
|
self.entropy_model,
|
||||||
|
self.patching_batch_size,
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
if self.log_time:
|
||||||
|
self.log["calculate_entropies"] += time.time() - s
|
||||||
|
s = time.time()
|
||||||
|
patch_start_ids = find_entropy_patch_start_ids(
|
||||||
|
scores,
|
||||||
|
self.patch_size,
|
||||||
|
include_next_token=include_next_token,
|
||||||
|
threshold=threshold if threshold is not None else self.threshold,
|
||||||
|
threshold_add=self.threshold_add,
|
||||||
|
monotonicity=self.monotonicity,
|
||||||
|
)
|
||||||
|
if self.log_time:
|
||||||
|
self.log["find_entropy_patch_start_ids"] += time.time() - s
|
||||||
|
s = time.time()
|
||||||
|
patch_lengths = patch_lengths_from_start_ids(
|
||||||
|
patch_start_ids, seq_len_next_tok
|
||||||
|
)
|
||||||
|
if self.log_time:
|
||||||
|
self.log["patch_lengths_from_start_ids"] += time.time() - s
|
||||||
|
s = time.time()
|
||||||
|
# BPE
|
||||||
|
elif self.patching_mode == PatchingModeEnum.bpe:
|
||||||
|
patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID)
|
||||||
|
patch_lengths = patch_lengths_from_start_ids(
|
||||||
|
patch_start_ids, seq_len_next_tok
|
||||||
|
)
|
||||||
|
elif self.patching_mode == PatchingModeEnum.bpe_patcher:
|
||||||
|
patch_start_ids = find_bpe_patcher_patch_start_ids(
|
||||||
|
tokens,
|
||||||
|
self.entropy_model,
|
||||||
|
self.patching_batch_size,
|
||||||
|
self.device,
|
||||||
|
include_next_token,
|
||||||
|
)
|
||||||
|
patch_lengths = patch_lengths_from_start_ids(
|
||||||
|
patch_start_ids, seq_len_next_tok
|
||||||
|
)
|
||||||
|
# SPACE
|
||||||
|
elif self.patching_mode == PatchingModeEnum.space:
|
||||||
|
patch_start_ids = find_space_patch_start_ids(tokens)
|
||||||
|
patch_lengths = patch_lengths_from_start_ids(
|
||||||
|
patch_start_ids, seq_len_next_tok
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"self.patching_mode {self.patching_mode}")
|
||||||
|
|
||||||
|
# Apply any processing to patch lengths
|
||||||
|
if self.max_patch_length is not None:
|
||||||
|
# TODO: avoid going back to a list here.
|
||||||
|
patch_lengths = [
|
||||||
|
split_large_numbers(pl, self.max_patch_length)
|
||||||
|
for pl in patch_lengths.tolist()
|
||||||
|
]
|
||||||
|
max_len = max([len(pl) for pl in patch_lengths])
|
||||||
|
patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
|
||||||
|
patch_lengths = torch.tensor(
|
||||||
|
patch_lengths, dtype=tokens.dtype, device=tokens.device
|
||||||
|
)
|
||||||
|
assert not check_non_zero_after_zero(patch_lengths)
|
||||||
|
# Find the last non-zero column index using argmax on a reversed version of the tensor
|
||||||
|
last_non_zero_col_reversed = (
|
||||||
|
(patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
|
||||||
|
)
|
||||||
|
# Slice the tensor up to the last non-zero column
|
||||||
|
patch_lengths = patch_lengths[
|
||||||
|
:, : patch_lengths.shape[1] - last_non_zero_col_reversed
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
torch.sum(patch_lengths)
|
||||||
|
== tokens.numel() + include_next_token * tokens.shape[0]
|
||||||
|
), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}"
|
||||||
|
if self.log_time:
|
||||||
|
self.log["postprocessing_patch_lengths"] += time.time() - s
|
||||||
|
self.log["tokens"] += patch_lengths.sum().item()
|
||||||
|
return patch_lengths, scores
|
478
bytelatent/distributed.py
Normal file
478
bytelatent/distributed.py
Normal file
|
@ -0,0 +1,478 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
import signal
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from functools import lru_cache, partial, reduce
|
||||||
|
from itertools import chain
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# for no recompute ops
|
||||||
|
import xformers.ops
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from torch import distributed as dist
|
||||||
|
from torch.distributed import ReduceOp
|
||||||
|
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
|
||||||
|
from torch.distributed._tensor import DTensor
|
||||||
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||||
|
checkpoint_wrapper,
|
||||||
|
)
|
||||||
|
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.utils.checkpoint import (
|
||||||
|
CheckpointPolicy,
|
||||||
|
create_selective_checkpoint_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
from bytelatent.float8 import convert_linears_to_fp8
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
# for selective AC
|
||||||
|
default_no_recompute_ops = {
|
||||||
|
torch.ops.aten.mm.default,
|
||||||
|
torch.ops.aten._scaled_mm.default,
|
||||||
|
torch.ops.aten._scaled_dot_product_efficient_attention.default,
|
||||||
|
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
||||||
|
torch.ops.c10d_functional.reduce_scatter_tensor.default,
|
||||||
|
torch.ops.xformers_flash.flash_fwd.default,
|
||||||
|
torch.ops.xformers.efficient_attention_forward_cutlass.default,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
dp_shard: int = (
|
||||||
|
1 # In how many shard to split the model weight. Typically number gpu in a node.
|
||||||
|
)
|
||||||
|
dp_replicate: int = (
|
||||||
|
1 # How many times to replicate the model weight. Typically number of nodes.
|
||||||
|
)
|
||||||
|
tp_size: int = 1
|
||||||
|
selective_activation_checkpointing: bool = False
|
||||||
|
compile: bool = False
|
||||||
|
fsdp_type: str = "no_shard"
|
||||||
|
model_dtype: str = "bf16"
|
||||||
|
float8_recipe: str | None = None
|
||||||
|
float8_filter: str = r"layers\.[0-9]+\."
|
||||||
|
|
||||||
|
matmul_allow_tf32: bool = False
|
||||||
|
allow_bf16_reduced_precision_reduction: bool = True
|
||||||
|
detect_anomaly: bool = False
|
||||||
|
|
||||||
|
compile_cache_size_limit: int = 8
|
||||||
|
|
||||||
|
spawn_method: str = "forkserver"
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
# Use GNU openMP (GOMP) instead of Intel OpenMP [Intel Math Kernel Library (MKL)]
|
||||||
|
MKL_SERVICE_FORCE_INTEL: str = "GNU"
|
||||||
|
OMP_NUM_THREADS: str = "1"
|
||||||
|
MKL_NUM_THREADS: str = "1"
|
||||||
|
# faster intra-node collectives, seems to be a cluster specific flag
|
||||||
|
ENABLE_INTRA_NODE_COMM: str = "1"
|
||||||
|
# avoids OOMs with long context
|
||||||
|
TORCH_NCCL_AVOID_RECORD_STREAMS: str = "1"
|
||||||
|
# increasing NCCL timeout time before having some NCCL error 22 should give a 16s timeout
|
||||||
|
NCCL_IB_TIMEOUT: str = "22"
|
||||||
|
NCCL_DEBUG: str = "INFO"
|
||||||
|
TORCH_NCCL_ASYNC_ERROR_HANDLING: str = "1"
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_mesh(distributed_args: DistributedArgs):
|
||||||
|
tp_size = distributed_args.tp_size
|
||||||
|
dp_replicate = distributed_args.dp_replicate
|
||||||
|
dp_shard = distributed_args.dp_shard
|
||||||
|
|
||||||
|
assert (
|
||||||
|
dp_replicate * dp_shard * tp_size == get_world_size()
|
||||||
|
), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"
|
||||||
|
|
||||||
|
dims = []
|
||||||
|
names = []
|
||||||
|
if dp_replicate >= 1:
|
||||||
|
dims.append(dp_replicate)
|
||||||
|
names.append("dp_replicate")
|
||||||
|
if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
|
||||||
|
dims.append(dp_shard)
|
||||||
|
names.append("dp_shard")
|
||||||
|
if tp_size > 1:
|
||||||
|
dims.append(tp_size)
|
||||||
|
names.append("tp")
|
||||||
|
dims = tuple(dims)
|
||||||
|
names = tuple(names)
|
||||||
|
|
||||||
|
return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)
|
||||||
|
|
||||||
|
|
||||||
|
def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
|
||||||
|
tensor = torch.tensor(x).cuda()
|
||||||
|
dist.all_reduce(tensor, op=ReduceOp.MAX, group=mesh.get_group() if mesh else None)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
|
||||||
|
tensor = torch.tensor(x).cuda()
|
||||||
|
dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def dist_mean_dict(x):
|
||||||
|
r = dict()
|
||||||
|
for k in x:
|
||||||
|
r[k] = dist_mean(x[k])
|
||||||
|
r[k] = r[k].item() if (r[k].dim() == 0) else r[k].tolist()
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_is_torch_run() -> bool:
|
||||||
|
return os.environ.get("LOCAL_RANK") is not None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_is_slurm_job() -> bool:
|
||||||
|
return "SLURM_JOB_ID" in os.environ and not get_is_torch_run()
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_global_rank() -> int:
|
||||||
|
if get_is_torch_run():
|
||||||
|
return int(os.environ["RANK"])
|
||||||
|
elif get_is_slurm_job():
|
||||||
|
return int(os.environ["SLURM_PROCID"])
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_local_rank() -> int:
|
||||||
|
if get_is_torch_run():
|
||||||
|
return int(os.environ["LOCAL_RANK"])
|
||||||
|
elif get_is_slurm_job():
|
||||||
|
return int(os.environ["SLURM_LOCALID"])
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_world_size() -> int:
|
||||||
|
if get_is_torch_run():
|
||||||
|
return int(os.environ["WORLD_SIZE"])
|
||||||
|
elif get_is_slurm_job():
|
||||||
|
return int(os.environ["SLURM_NTASKS"])
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_is_master() -> bool:
|
||||||
|
return get_global_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_master_port(job_id: int) -> int:
|
||||||
|
if get_is_torch_run():
|
||||||
|
return int(os.environ["MASTER_PORT"])
|
||||||
|
else:
|
||||||
|
MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000)
|
||||||
|
rng = random.Random(job_id)
|
||||||
|
return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_master_addr() -> str:
|
||||||
|
if get_is_torch_run():
|
||||||
|
return os.environ["MASTER_ADDR"]
|
||||||
|
elif get_is_slurm_job():
|
||||||
|
hostnames = subprocess.check_output(
|
||||||
|
["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]]
|
||||||
|
)
|
||||||
|
return hostnames.split()[0].decode("utf-8")
|
||||||
|
else:
|
||||||
|
return "127.0.0.1"
|
||||||
|
|
||||||
|
|
||||||
|
def setup_env(env_args: EnvironmentArgs):
|
||||||
|
env_vars = env_args.model_dump()
|
||||||
|
|
||||||
|
# When using Triton, it attempts to locate prebuilt kernels in a cache
|
||||||
|
# located at ~/.triton/cache, but when that's backed by NFS this can fail
|
||||||
|
# with a "OSError: [Errno 116] Stale file handle" error. If we were to set
|
||||||
|
# it to a local directory it would belong to the first user who created it
|
||||||
|
# and it would fail for the job of any other successive user assigned to
|
||||||
|
# that machine. To avoid all this mess we use a temporary per-process cache.
|
||||||
|
triton_cache_dir = tempfile.mkdtemp()
|
||||||
|
atexit.register(shutil.rmtree, triton_cache_dir, ignore_errors=True)
|
||||||
|
env_vars["TRITON_CACHE_DIR"] = triton_cache_dir
|
||||||
|
|
||||||
|
# We change the tmp dir to /scratch in case it's slurm job
|
||||||
|
# This avoids filling up the host's usually limited tmpfs
|
||||||
|
# A full tmpfs leads to very slow creation of processes and weird bugs
|
||||||
|
if get_is_slurm_job():
|
||||||
|
new_tmp = f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}"
|
||||||
|
if os.path.exists(new_tmp):
|
||||||
|
env_vars["TMP_DIR"] = new_tmp
|
||||||
|
|
||||||
|
for name, value in env_vars.items():
|
||||||
|
if os.environ.get(name) != str(value):
|
||||||
|
os.environ[name] = str(value)
|
||||||
|
logger.warning(f"WARNING: Setting {name} to {value}")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_torch_distributed(dist_args):
|
||||||
|
"""
|
||||||
|
Handle single and multi-GPU / multi-node / SLURM jobs.
|
||||||
|
Initialize the following variables:
|
||||||
|
- global_rank
|
||||||
|
- world_size
|
||||||
|
"""
|
||||||
|
mp.set_start_method(dist_args.spawn_method)
|
||||||
|
with mp.Manager():
|
||||||
|
pass
|
||||||
|
|
||||||
|
local_rank = get_local_rank()
|
||||||
|
|
||||||
|
os.environ["RANK"] = str(get_global_rank())
|
||||||
|
os.environ["WORLD_SIZE"] = str(get_world_size())
|
||||||
|
os.environ["MASTER_ADDR"] = get_master_addr()
|
||||||
|
os.environ["MASTER_PORT"] = str(
|
||||||
|
get_master_port(job_id=int(os.environ.get("SLURM_JOB_ID", -1)))
|
||||||
|
)
|
||||||
|
|
||||||
|
if get_is_torch_run():
|
||||||
|
logger.info(f"Run launched with torchrun, local rank: {local_rank}")
|
||||||
|
elif get_is_slurm_job():
|
||||||
|
logger.info(f"Run launched with slurm, local rank: {local_rank}")
|
||||||
|
else:
|
||||||
|
logger.info("Single GPU job")
|
||||||
|
|
||||||
|
logger.info(f"ENV: {os.environ}")
|
||||||
|
|
||||||
|
# set GPU device
|
||||||
|
assert 0 <= local_rank < 8
|
||||||
|
if dist_args.matmul_allow_tf32:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
logger.warning(
|
||||||
|
f"WARNING: Setting torch.backends.matmul.allow_tf32 to True. This is faster but less accurate."
|
||||||
|
)
|
||||||
|
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
|
||||||
|
dist_args.allow_bf16_reduced_precision_reduction
|
||||||
|
)
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
torch.distributed.init_process_group(init_method="env://", backend="nccl")
|
||||||
|
torch.autograd.set_detect_anomaly(dist_args.detect_anomaly)
|
||||||
|
|
||||||
|
|
||||||
|
def get_module(module, access_string):
|
||||||
|
names = access_string.split(sep=".")
|
||||||
|
return reduce(getattr, names, module)
|
||||||
|
|
||||||
|
|
||||||
|
def set_module(module, access_string, value):
|
||||||
|
names = access_string.split(sep=".")
|
||||||
|
parent = reduce(getattr, names[:-1], module)
|
||||||
|
setattr(parent, names[-1], value)
|
||||||
|
|
||||||
|
|
||||||
|
def default_fsdp_grouping_plan(n_layers: int) -> List[Tuple[str, bool]]:
|
||||||
|
return [(f"layers.{i}", i < n_layers - 1) for i in range(n_layers)]
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_policy(no_recompute_ops=None):
|
||||||
|
no_recompute_ops = no_recompute_ops or default_no_recompute_ops
|
||||||
|
|
||||||
|
def default_policy(ctx, func, *args, **kwargs):
|
||||||
|
return (
|
||||||
|
CheckpointPolicy.MUST_SAVE
|
||||||
|
if func in no_recompute_ops
|
||||||
|
else CheckpointPolicy.PREFER_RECOMPUTE
|
||||||
|
)
|
||||||
|
|
||||||
|
return default_policy
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def check_model_value_range(
|
||||||
|
model: torch.nn.Module, range: float = 1e3, std: float = 1e3
|
||||||
|
):
|
||||||
|
for name, param in chain(model.named_parameters(), model.named_buffers()):
|
||||||
|
if isinstance(param, DTensor):
|
||||||
|
param = param.to_local()
|
||||||
|
|
||||||
|
if param.numel() == 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Model parameter {name} is empty, probably because of FSDP sharding"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if torch.isnan(param).any() or torch.isinf(param).any():
|
||||||
|
logger.warning(f"Model parameter {name} contains NaN or Inf")
|
||||||
|
|
||||||
|
param_range = param.max() - param.min()
|
||||||
|
param_std = param.std()
|
||||||
|
if param_range > range:
|
||||||
|
logger.warning(
|
||||||
|
f"Model parameter {name} has a suspiciously large range ({param_range}): please check initialization and init_weights is defined and called"
|
||||||
|
)
|
||||||
|
if param_std > std:
|
||||||
|
logger.warning(
|
||||||
|
f"Model parameter {name} has a suspiciously large standard deviation ({param_std}): please check initialization and init_weights is defined and called"
|
||||||
|
)
|
||||||
|
if (param == 0).all():
|
||||||
|
logger.warning(
|
||||||
|
f"Model parameter {name} is all zeros: it might be because of a missing initialization"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_signal_handler(callable):
|
||||||
|
"""
|
||||||
|
Handle signals sent by SLURM for time limit / pre-emption.
|
||||||
|
"""
|
||||||
|
signal.signal(signal.SIGUSR2, callable)
|
||||||
|
logger.warning("Signal handler installed.")
|
||||||
|
|
||||||
|
|
||||||
|
def requeue_slurm_job():
|
||||||
|
prod_id = int(os.environ["SLURM_PROCID"])
|
||||||
|
logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
|
||||||
|
if prod_id == 0 and os.environ.get("LAUNCH_WITH", "") != "DORA":
|
||||||
|
logger.warning("Requeuing job " + os.environ["SLURM_JOB_ID"])
|
||||||
|
os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
|
||||||
|
else:
|
||||||
|
logger.warning("Not the master process, no need to requeue.")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def clean_env():
|
||||||
|
distrib_names = (
|
||||||
|
"MASTER_ADDR",
|
||||||
|
"MASTER_PORT",
|
||||||
|
"RANK",
|
||||||
|
"WORLD_SIZE",
|
||||||
|
"LOCAL_RANK",
|
||||||
|
"LOCAL_WORLD_SIZE",
|
||||||
|
"TORCHELASTIC_RUN_ID",
|
||||||
|
"DORA_FORCE_DISTRIB",
|
||||||
|
)
|
||||||
|
cluster_env = {
|
||||||
|
x: os.environ.pop(x)
|
||||||
|
for x in os.environ
|
||||||
|
if x.startswith(
|
||||||
|
("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_", "WANDB_")
|
||||||
|
)
|
||||||
|
or x in distrib_names
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
os.environ.update(cluster_env)
|
||||||
|
|
||||||
|
|
||||||
|
def parallelize_model(
|
||||||
|
model,
|
||||||
|
device_mesh,
|
||||||
|
model_args,
|
||||||
|
distributed_args: DistributedArgs,
|
||||||
|
fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
|
||||||
|
tp_parallelize=None,
|
||||||
|
no_recompute_ops=None,
|
||||||
|
):
|
||||||
|
if distributed_args.tp_size > 1:
|
||||||
|
assert (
|
||||||
|
distributed_args.fsdp_type == "full_shard"
|
||||||
|
), "Only full shard is supported for TP parallelism"
|
||||||
|
assert tp_parallelize is not None, "TP plan is required for TP parallelism"
|
||||||
|
assert (
|
||||||
|
distributed_args.compile == False
|
||||||
|
), "Compile is not supported for TP parallelism"
|
||||||
|
|
||||||
|
tp_parallelize(model, device_mesh["tp"], model_args, distributed_args)
|
||||||
|
|
||||||
|
if distributed_args.float8_recipe is not None:
|
||||||
|
if distributed_args.tp_size > 1:
|
||||||
|
raise RuntimeError("float8 is incompatible with tensor-parallelism for now")
|
||||||
|
model = convert_linears_to_fp8(
|
||||||
|
model, distributed_args.float8_recipe, distributed_args.float8_filter
|
||||||
|
)
|
||||||
|
|
||||||
|
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
|
||||||
|
distributed_args.model_dtype
|
||||||
|
]
|
||||||
|
if (
|
||||||
|
distributed_args.fsdp_type == "full_shard"
|
||||||
|
or distributed_args.fsdp_type == "no_shard"
|
||||||
|
):
|
||||||
|
if distributed_args.fsdp_type == "no_shard":
|
||||||
|
assert (
|
||||||
|
distributed_args.dp_shard == 1
|
||||||
|
), "dp_shard must be 1 for no_shard fsdp_type"
|
||||||
|
assert (
|
||||||
|
device_mesh["dp_shard"].size() == 1
|
||||||
|
), "dp_shard must be 1 for no_shard fsdp_type"
|
||||||
|
|
||||||
|
fsdp_config = dict(
|
||||||
|
mp_policy=(
|
||||||
|
MixedPrecisionPolicy(
|
||||||
|
param_dtype=param_dtype,
|
||||||
|
reduce_dtype=torch.float32,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
mesh=(
|
||||||
|
device_mesh["dp_replicate", "dp_shard"]
|
||||||
|
if distributed_args.dp_shard > 1
|
||||||
|
or distributed_args.fsdp_type == "no_shard"
|
||||||
|
else device_mesh["dp_replicate"]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if fsdp_grouping_plan is None:
|
||||||
|
# Assume that the model has list of layers and group around it
|
||||||
|
fsdp_grouping_plan = default_fsdp_grouping_plan(len(model.layers))
|
||||||
|
|
||||||
|
for path, reshard_after_forward in fsdp_grouping_plan:
|
||||||
|
module = get_module(model, path)
|
||||||
|
set_module(
|
||||||
|
model,
|
||||||
|
path,
|
||||||
|
fully_shard(
|
||||||
|
module, **fsdp_config, reshard_after_forward=reshard_after_forward
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
model = fully_shard(model, **fsdp_config, reshard_after_forward=True)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
|
||||||
|
|
||||||
|
if distributed_args.selective_activation_checkpointing:
|
||||||
|
model = checkpoint_wrapper(
|
||||||
|
model,
|
||||||
|
context_fn=partial(
|
||||||
|
create_selective_checkpoint_contexts,
|
||||||
|
get_default_policy(no_recompute_ops),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if distributed_args.compile:
|
||||||
|
torch._dynamo.config.cache_size_limit = (
|
||||||
|
distributed_args.compile_cache_size_limit
|
||||||
|
)
|
||||||
|
model = torch.compile(model)
|
||||||
|
|
||||||
|
return model
|
36
bytelatent/entropy_model.py
Normal file
36
bytelatent/entropy_model.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from bytelatent.transformer import LMTransformer, LMTransformerArgs
|
||||||
|
|
||||||
|
|
||||||
|
def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
|
||||||
|
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
|
||||||
|
reloaded = json.loads(fr.read())
|
||||||
|
|
||||||
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
|
model_params = reloaded["model"]
|
||||||
|
entropy_model = LMTransformer(
|
||||||
|
LMTransformerArgs(
|
||||||
|
dim=model_params["dim"],
|
||||||
|
n_layers=model_params["n_layers"],
|
||||||
|
n_heads=model_params["n_heads"],
|
||||||
|
max_seqlen=model_params["max_length"],
|
||||||
|
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
|
||||||
|
vocab_size=model_params["vocab_size"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
entropy_model.load_state_dict(
|
||||||
|
torch.load(state_dict_path, map_location=device), strict=False
|
||||||
|
)
|
||||||
|
entropy_model.to(device)
|
||||||
|
entropy_model = entropy_model.eval()
|
||||||
|
# no grads for the model:
|
||||||
|
for param in entropy_model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
return entropy_model
|
152
bytelatent/float8.py
Normal file
152
bytelatent/float8.py
Normal file
|
@ -0,0 +1,152 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
import re
|
||||||
|
import warnings
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# avoid division by zero when calculating scale
|
||||||
|
EPS = 1e-12
|
||||||
|
|
||||||
|
|
||||||
|
def scale(t, amax_t, dtype_t):
|
||||||
|
min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max
|
||||||
|
scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v
|
||||||
|
t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t)
|
||||||
|
return t_fp8, scale_t
|
||||||
|
|
||||||
|
|
||||||
|
def matmul(
|
||||||
|
first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias
|
||||||
|
):
|
||||||
|
first_fp8, scale_first = scale(first, amax_first, dtype_first)
|
||||||
|
second_t_fp8, scale_second_t = scale(second_t, amax_second_t, dtype_second_t)
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
first_fp8,
|
||||||
|
second_t_fp8.t(),
|
||||||
|
scale_a=scale_first,
|
||||||
|
scale_b=scale_second_t.t(),
|
||||||
|
bias=bias,
|
||||||
|
out_dtype=torch.bfloat16,
|
||||||
|
use_fast_accum=True,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@torch._dynamo.allow_in_graph
|
||||||
|
class Fp8LinearFn(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, a, b_t, bias):
|
||||||
|
amax_a = a.abs().amax(dim=-1, keepdim=True)
|
||||||
|
amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
|
||||||
|
out = matmul(
|
||||||
|
a, amax_a, torch.float8_e4m3fn, b_t, amax_b_t, torch.float8_e4m3fn, bias
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx.a_requires_grad = a.requires_grad
|
||||||
|
ctx.b_requires_grad = b_t.requires_grad
|
||||||
|
ctx.bias_requires_grad = bias.requires_grad if bias is not None else False
|
||||||
|
|
||||||
|
ctx.save_for_backward(a, b_t, amax_b_t.max())
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_out):
|
||||||
|
a, b_t, amax_b = ctx.saved_tensors
|
||||||
|
|
||||||
|
if ctx.a_requires_grad:
|
||||||
|
b = b_t.t().contiguous()
|
||||||
|
amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True)
|
||||||
|
amax_b = amax_b.repeat(b.shape[0], 1)
|
||||||
|
grad_a = matmul(
|
||||||
|
grad_out,
|
||||||
|
amax_grad_out,
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
b,
|
||||||
|
amax_b,
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
grad_a = None
|
||||||
|
if ctx.b_requires_grad:
|
||||||
|
grad_b = grad_out.t() @ a
|
||||||
|
else:
|
||||||
|
grad_b = None
|
||||||
|
if ctx.bias_requires_grad:
|
||||||
|
grad_bias = grad_out.sum(dim=0)
|
||||||
|
else:
|
||||||
|
grad_bias = None
|
||||||
|
|
||||||
|
return grad_a, grad_b, grad_bias
|
||||||
|
|
||||||
|
|
||||||
|
class Fp8Linear(torch.nn.Linear):
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias)
|
||||||
|
out = out.unflatten(0, input.shape[:-1])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def named_replace(
|
||||||
|
fn: Callable[[torch.nn.Module, str], torch.nn.Module],
|
||||||
|
module: torch.nn.Module,
|
||||||
|
name="",
|
||||||
|
) -> torch.nn.Module:
|
||||||
|
for child_name, child_module in list(module.named_children()):
|
||||||
|
full_name = f"{name}.{child_name}" if name else child_name
|
||||||
|
new_child_module = named_replace(fn, child_module, full_name)
|
||||||
|
setattr(module, child_name, new_child_module)
|
||||||
|
module = fn(module, name)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def convert_linears_to_fp8(
|
||||||
|
root_module: torch.nn.Module, recipe: str, filter: str
|
||||||
|
) -> torch.nn.Module:
|
||||||
|
if recipe not in ["rowwise"]:
|
||||||
|
raise RuntimeError(f"Unknown float8 recipe {recipe!r}")
|
||||||
|
|
||||||
|
if recipe == "rowwise" and torch.__version__ < "2.5":
|
||||||
|
# We need https://github.com/pytorch/pytorch/pull/134781.
|
||||||
|
warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0")
|
||||||
|
|
||||||
|
# Multi-kernel makes Inductor auto-tune between a regular "streaming"-based
|
||||||
|
# reduction kernel and a "persistent" reduction kernel. Since fp8 has some
|
||||||
|
# multi-pass steps (e.g., first get amax, then scale), persistent kernels
|
||||||
|
# should perform better.
|
||||||
|
torch._inductor.config.triton.multi_kernel = 1
|
||||||
|
|
||||||
|
filter_re = re.compile(filter)
|
||||||
|
|
||||||
|
def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
|
||||||
|
if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
|
||||||
|
return module
|
||||||
|
if type(module) == torch.nn.Linear:
|
||||||
|
if recipe == "rowwise":
|
||||||
|
new_module = Fp8Linear(
|
||||||
|
in_features=module.in_features,
|
||||||
|
out_features=module.out_features,
|
||||||
|
bias=module.bias is not None,
|
||||||
|
dtype=module.weight.dtype,
|
||||||
|
device=module.weight.device,
|
||||||
|
)
|
||||||
|
new_module.weight = module.weight
|
||||||
|
new_module.bias = module.bias
|
||||||
|
else:
|
||||||
|
assert False, recipe
|
||||||
|
else:
|
||||||
|
assert False, str(type(module))
|
||||||
|
return new_module
|
||||||
|
|
||||||
|
out = named_replace(replace, root_module)
|
||||||
|
|
||||||
|
# Force re-compile everything
|
||||||
|
torch._dynamo.reset_code_caches()
|
||||||
|
from torch._inductor.cudagraph_trees import reset_cudagraph_trees
|
||||||
|
|
||||||
|
reset_cudagraph_trees()
|
||||||
|
|
||||||
|
return out
|
129
bytelatent/logger.py
Normal file
129
bytelatent/logger.py
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from bytelatent.distributed import get_global_rank, get_is_slurm_job
|
||||||
|
|
||||||
|
|
||||||
|
class LogFormatter(logging.Formatter):
|
||||||
|
"""
|
||||||
|
Custom logger for distributed jobs, displaying rank
|
||||||
|
and preserving indent from the custom prefix format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.start_time = time.time()
|
||||||
|
self.rank = get_global_rank()
|
||||||
|
self.show_rank = not get_is_slurm_job() # srun has --label
|
||||||
|
|
||||||
|
def formatTime(self, record):
|
||||||
|
subsecond, seconds = math.modf(record.created)
|
||||||
|
curr_date = (
|
||||||
|
time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
|
||||||
|
+ f".{int(subsecond * 1_000_000):06d}"
|
||||||
|
)
|
||||||
|
delta = timedelta(seconds=round(record.created - self.start_time))
|
||||||
|
return f"{curr_date} - {delta}"
|
||||||
|
|
||||||
|
def formatPrefix(self, record):
|
||||||
|
fmt_time = self.formatTime(record)
|
||||||
|
if self.show_rank:
|
||||||
|
return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
|
||||||
|
else:
|
||||||
|
return f"{record.levelname:<7} {fmt_time} - "
|
||||||
|
|
||||||
|
def formatMessage(self, record, indent: str):
|
||||||
|
content = record.getMessage()
|
||||||
|
content = content.replace("\n", "\n" + indent)
|
||||||
|
# Exception handling as in the default formatter, albeit with indenting
|
||||||
|
# according to our custom prefix
|
||||||
|
if record.exc_info:
|
||||||
|
# Cache the traceback text to avoid converting it multiple times
|
||||||
|
# (it's constant anyway)
|
||||||
|
if not record.exc_text:
|
||||||
|
record.exc_text = self.formatException(record.exc_info)
|
||||||
|
if record.exc_text:
|
||||||
|
if content[-1:] != "\n":
|
||||||
|
content = content + "\n" + indent
|
||||||
|
content = content + indent.join(
|
||||||
|
[l + "\n" for l in record.exc_text.splitlines()]
|
||||||
|
)
|
||||||
|
if content[-1:] == "\n":
|
||||||
|
content = content[:-1]
|
||||||
|
if record.stack_info:
|
||||||
|
if content[-1:] != "\n":
|
||||||
|
content = content + "\n" + indent
|
||||||
|
stack_text = self.formatStack(record.stack_info)
|
||||||
|
content = content + indent.join([l + "\n" for l in stack_text.splitlines()])
|
||||||
|
if content[-1:] == "\n":
|
||||||
|
content = content[:-1]
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
def format(self, record):
|
||||||
|
prefix = self.formatPrefix(record)
|
||||||
|
indent = " " * len(prefix)
|
||||||
|
content = self.formatMessage(record, indent)
|
||||||
|
return prefix + content
|
||||||
|
|
||||||
|
|
||||||
|
def set_root_log_level(log_level: str):
|
||||||
|
logger = logging.getLogger()
|
||||||
|
level: int | str = log_level.upper()
|
||||||
|
try:
|
||||||
|
level = int(log_level)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
logger.setLevel(level) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to set logging level to {log_level}, using default 'NOTSET'"
|
||||||
|
)
|
||||||
|
logger.setLevel(logging.NOTSET)
|
||||||
|
|
||||||
|
|
||||||
|
def init_logger(
|
||||||
|
log_file: str | None = None,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
level: str = "NOTSET",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Setup logging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_file: A file name to save file logs to.
|
||||||
|
name: The name of the logger to configure, by default the root logger.
|
||||||
|
level: The logging level to use.
|
||||||
|
"""
|
||||||
|
set_root_log_level(level)
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
|
||||||
|
# stdout: everything
|
||||||
|
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
stdout_handler.setLevel(logging.NOTSET)
|
||||||
|
stdout_handler.setFormatter(LogFormatter())
|
||||||
|
|
||||||
|
# stderr: warnings / errors and above
|
||||||
|
stderr_handler = logging.StreamHandler(sys.stderr)
|
||||||
|
stderr_handler.setLevel(logging.WARNING)
|
||||||
|
stderr_handler.setFormatter(LogFormatter())
|
||||||
|
|
||||||
|
# set stream handlers
|
||||||
|
logger.handlers.clear()
|
||||||
|
logger.handlers.append(stdout_handler)
|
||||||
|
logger.handlers.append(stderr_handler)
|
||||||
|
|
||||||
|
if log_file is not None and get_global_rank() == 0:
|
||||||
|
# build file handler
|
||||||
|
file_handler = logging.FileHandler(log_file, "a")
|
||||||
|
file_handler.setLevel(logging.NOTSET)
|
||||||
|
file_handler.setFormatter(LogFormatter())
|
||||||
|
# update logger
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logger.addHandler(file_handler)
|
232
bytelatent/metrics.py
Normal file
232
bytelatent/metrics.py
Normal file
|
@ -0,0 +1,232 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections import namedtuple
|
||||||
|
from dataclasses import asdict
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import wandb
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from bytelatent.distributed import get_is_master
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class WandbArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
job_type: str | None = None
|
||||||
|
dir: str | None = None
|
||||||
|
project: str | None = None
|
||||||
|
entity: str | None = None
|
||||||
|
tags: list | None = None
|
||||||
|
group: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
notes: str | None = None
|
||||||
|
config_exclude_keys: list[str] | None = None
|
||||||
|
config_include_keys: list[str] | None = None
|
||||||
|
anonymous: str | None = None
|
||||||
|
mode: str | None = None
|
||||||
|
allow_val_change: bool | None = None
|
||||||
|
resume: Union[bool, str] | None = None
|
||||||
|
force: bool | None = None
|
||||||
|
tensorboard: bool | None = None
|
||||||
|
sync_tensorboard: bool | None = None
|
||||||
|
monitor_gym: bool | None = None
|
||||||
|
save_code: bool | None = None
|
||||||
|
id: str | None = None
|
||||||
|
fork_from: str | None = None
|
||||||
|
resume_from: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
freq: int = 10 # Log every freq optimizer steps
|
||||||
|
acc_freq: int | None = None # Log every acc_freq gradient accumulation steps
|
||||||
|
|
||||||
|
wandb: WandbArgs | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MetricLogger:
|
||||||
|
def __init__(self, outdir: Path, args: Any | None = None):
|
||||||
|
self.outdir = outdir
|
||||||
|
self.jsonl_writer = None
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def open(self):
|
||||||
|
if self.jsonl_writer is None:
|
||||||
|
self.jsonl_writer = open(self.outdir, "a")
|
||||||
|
if (
|
||||||
|
self.args is not None
|
||||||
|
and self.args.logging.wandb is not None
|
||||||
|
and get_is_master()
|
||||||
|
):
|
||||||
|
run = wandb.init(
|
||||||
|
config=asdict(self.args),
|
||||||
|
**asdict(self.args.logging.wandb),
|
||||||
|
)
|
||||||
|
|
||||||
|
def log(self, metrics: dict[str, Any]):
|
||||||
|
if (
|
||||||
|
self.args is not None
|
||||||
|
and self.args.logging.wandb is not None
|
||||||
|
and (wandb.run is not None)
|
||||||
|
):
|
||||||
|
wandb.log(metrics, step=metrics["global_step"])
|
||||||
|
|
||||||
|
metrics.update({"created_at": datetime.now(timezone.utc).isoformat()})
|
||||||
|
print(json.dumps(metrics), file=self.jsonl_writer, flush=True)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.jsonl_writer is not None:
|
||||||
|
self.jsonl_writer.close()
|
||||||
|
self.jsonl_writer = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.open()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
|
GPUMemStats = namedtuple(
|
||||||
|
"GPUMemStats",
|
||||||
|
[
|
||||||
|
"max_active_gib",
|
||||||
|
"max_active_pct",
|
||||||
|
"max_reserved_gib",
|
||||||
|
"max_reserved_pct",
|
||||||
|
"num_alloc_retries",
|
||||||
|
"num_ooms",
|
||||||
|
"power_draw",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GPUMemoryMonitor:
|
||||||
|
"""
|
||||||
|
Class to monitor GPU memory usage
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device: str = "cuda:0"):
|
||||||
|
self.device = torch.device(device) # device object
|
||||||
|
self.device_name = torch.cuda.get_device_name(self.device)
|
||||||
|
self.device_index = torch.cuda.current_device()
|
||||||
|
self.device_capacity = torch.cuda.get_device_properties(
|
||||||
|
self.device
|
||||||
|
).total_memory
|
||||||
|
self.device_capacity_gib = self._to_gib(self.device_capacity)
|
||||||
|
|
||||||
|
# reset stats, clear cache
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def _to_gib(self, memory_in_bytes):
|
||||||
|
# NOTE: GiB (gibibyte) is 1024, vs GB is 1000
|
||||||
|
_gib_in_bytes = 1024 * 1024 * 1024
|
||||||
|
memory_in_gib = memory_in_bytes / _gib_in_bytes
|
||||||
|
return memory_in_gib
|
||||||
|
|
||||||
|
def _to_pct(self, memory):
|
||||||
|
return 100 * memory / self.device_capacity
|
||||||
|
|
||||||
|
def get_peak_stats(self):
|
||||||
|
cuda_info = torch.cuda.memory_stats(self.device)
|
||||||
|
|
||||||
|
max_active = cuda_info["active_bytes.all.peak"]
|
||||||
|
max_active_gib = self._to_gib(max_active)
|
||||||
|
max_active_pct = self._to_pct(max_active)
|
||||||
|
|
||||||
|
max_reserved = cuda_info["reserved_bytes.all.peak"]
|
||||||
|
max_reserved_gib = self._to_gib(max_reserved)
|
||||||
|
max_reserved_pct = self._to_pct(max_reserved)
|
||||||
|
|
||||||
|
num_retries = cuda_info["num_alloc_retries"]
|
||||||
|
num_ooms = cuda_info["num_ooms"]
|
||||||
|
power_draw = torch.cuda.power_draw()
|
||||||
|
|
||||||
|
if num_retries > 0:
|
||||||
|
logger.warning(f"{num_retries} CUDA memory allocation retries.")
|
||||||
|
if num_ooms > 0:
|
||||||
|
logger.warning(f"{num_ooms} CUDA OOM errors thrown.")
|
||||||
|
|
||||||
|
return GPUMemStats(
|
||||||
|
max_active_gib,
|
||||||
|
max_active_pct,
|
||||||
|
max_reserved_gib,
|
||||||
|
max_reserved_pct,
|
||||||
|
num_retries,
|
||||||
|
num_ooms,
|
||||||
|
power_draw,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset_peak_stats(self):
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
torch.cuda.reset_accumulated_memory_stats()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
mem_stats = self.get_peak_stats()
|
||||||
|
display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib} GiB capacity, "
|
||||||
|
display_str += (
|
||||||
|
f"{mem_stats.max_reserved_gib} GiB peak, {mem_stats.max_reserved_pct}% peak"
|
||||||
|
)
|
||||||
|
return f"{display_str}"
|
||||||
|
|
||||||
|
|
||||||
|
def upload_train_to_wandb(
|
||||||
|
ckpt_dir, project="lingua", entity="codegen-team", train=True, eval=True
|
||||||
|
):
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import wandb
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml")
|
||||||
|
cfg = OmegaConf.to_container(cfg)
|
||||||
|
|
||||||
|
if train:
|
||||||
|
wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
|
||||||
|
|
||||||
|
with open(Path(ckpt_dir) / "metrics.jsonl") as f:
|
||||||
|
for l in f:
|
||||||
|
m = json.loads(l)
|
||||||
|
wandb.log(m, step=m["global_step"])
|
||||||
|
|
||||||
|
wandb.finish()
|
||||||
|
|
||||||
|
if eval:
|
||||||
|
wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
|
||||||
|
|
||||||
|
with open(Path(ckpt_dir) / "metrics.eval.jsonl") as f:
|
||||||
|
for l in f:
|
||||||
|
m = json.loads(l)
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
f"evals/{name.replace('/','.')}": value
|
||||||
|
for name, value in m.items()
|
||||||
|
if "/" in name
|
||||||
|
},
|
||||||
|
step=m["global_step"],
|
||||||
|
)
|
||||||
|
|
||||||
|
wandb.finish()
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_params(model: nn.Module) -> int:
|
||||||
|
"""
|
||||||
|
Get the total model params
|
||||||
|
Args : only_trainable: whether to only count trainable params
|
||||||
|
"""
|
||||||
|
numel = {n: p.numel() for n, p in model.named_parameters()}
|
||||||
|
return sum(numel.values())
|
1
bytelatent/model/__init__.py
Normal file
1
bytelatent/model/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
1064
bytelatent/model/blt.py
Normal file
1064
bytelatent/model/blt.py
Normal file
File diff suppressed because it is too large
Load diff
356
bytelatent/model/local_models.py
Normal file
356
bytelatent/model/local_models.py
Normal file
|
@ -0,0 +1,356 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.nn.attention.flex_attention import BlockMask
|
||||||
|
from xformers.ops import AttentionBias
|
||||||
|
|
||||||
|
from bytelatent.base_transformer import (
|
||||||
|
InitStdFactor,
|
||||||
|
RMSNorm,
|
||||||
|
RotaryEmbedding,
|
||||||
|
TransformerBlock,
|
||||||
|
)
|
||||||
|
from bytelatent.model.transformer import CrossAttention
|
||||||
|
from bytelatent.model.utils import create_causal_mask, downsample
|
||||||
|
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class LocalModelBase(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = args.dim
|
||||||
|
self.dropout = args.dropout
|
||||||
|
self.vocab_size = args.vocab_size + args.pm_size
|
||||||
|
self.patch_size = args.patch_size
|
||||||
|
|
||||||
|
self.efficient_attn = args.efficient_attn
|
||||||
|
self.sliding_window = args.sliding_window
|
||||||
|
self.use_rope = args.use_rope
|
||||||
|
self.init_std_factor = args.init_std_factor
|
||||||
|
self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
|
||||||
|
self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
|
||||||
|
self.cross_attn_k = getattr(args, "cross_attn_k", None)
|
||||||
|
|
||||||
|
self.boe_id = BOE_ID
|
||||||
|
|
||||||
|
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[TransformerBlock(args) for _ in range(args.n_layers)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
|
||||||
|
if not self.use_rope:
|
||||||
|
self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
|
||||||
|
else:
|
||||||
|
self.rope = RotaryEmbedding(
|
||||||
|
theta=args.rope_theta,
|
||||||
|
head_dim=args.head_dim or args.dim // args.n_heads,
|
||||||
|
max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
|
||||||
|
)
|
||||||
|
self.pos_embeddings = None
|
||||||
|
|
||||||
|
self.token_embedding_projection = (
|
||||||
|
nn.Linear(args.dim_token_emb, args.dim, bias=False)
|
||||||
|
if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.patch_embedding_projection = self._create_patch_projection(args)
|
||||||
|
|
||||||
|
def _should_create_patch_projection(self, args):
|
||||||
|
dimension_mismatch = (
|
||||||
|
getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check cross attention conditions
|
||||||
|
cross_attn_conditions = (
|
||||||
|
hasattr(args, "cross_attn_encoder")
|
||||||
|
and args.cross_attn_encoder
|
||||||
|
and getattr(args, "cross_attn_init_by_pooling")
|
||||||
|
) or (
|
||||||
|
hasattr(args, "cross_attn_decoder")
|
||||||
|
and args.cross_attn_decoder
|
||||||
|
and getattr(args, "cross_attn_init_by_pooling")
|
||||||
|
)
|
||||||
|
|
||||||
|
return dimension_mismatch or cross_attn_conditions
|
||||||
|
|
||||||
|
def _create_patch_projection(self, args):
|
||||||
|
if not self._should_create_patch_projection(args):
|
||||||
|
return None
|
||||||
|
|
||||||
|
output_dim = args.dim_token_emb * (self.cross_attn_k or 1)
|
||||||
|
|
||||||
|
return nn.Linear(
|
||||||
|
in_features=args.dim_patch_emb,
|
||||||
|
out_features=output_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_embedding(self, tokens, embeds):
|
||||||
|
if embeds is not None:
|
||||||
|
return embeds
|
||||||
|
else:
|
||||||
|
return self.tok_embeddings(tokens)
|
||||||
|
|
||||||
|
def init_weights(self, init_std=None):
|
||||||
|
self.rope.reset_parameters()
|
||||||
|
|
||||||
|
init_std = init_std or (self.dim ** (-0.5))
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.tok_embeddings.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=init_std,
|
||||||
|
a=-3 * init_std,
|
||||||
|
b=3 * init_std,
|
||||||
|
)
|
||||||
|
if self.pos_embeddings is not None:
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.pos_embeddings.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=init_std,
|
||||||
|
a=-3 * init_std,
|
||||||
|
b=3 * init_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
for depth, layer in enumerate(self.layers):
|
||||||
|
factor = {
|
||||||
|
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
|
||||||
|
InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
|
||||||
|
InitStdFactor.DIM_RATIO: self.dim / 4096,
|
||||||
|
InitStdFactor.DISABLED: 1.0,
|
||||||
|
}[self.init_std_factor]
|
||||||
|
|
||||||
|
layer.init_weights(init_std, factor)
|
||||||
|
|
||||||
|
if self.token_embedding_projection is not None:
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.token_embedding_projection.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=init_std,
|
||||||
|
a=-3 * init_std,
|
||||||
|
b=3 * init_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.patch_embedding_projection is not None:
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.patch_embedding_projection.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=init_std,
|
||||||
|
a=-3 * init_std,
|
||||||
|
b=3 * init_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(self, "output"):
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.output.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=init_std,
|
||||||
|
a=-3 * init_std,
|
||||||
|
b=3 * init_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.cross_attn_layers is not None:
|
||||||
|
for depth, layer in enumerate(self.cross_attn_layers):
|
||||||
|
factor = {
|
||||||
|
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
|
||||||
|
InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
|
||||||
|
InitStdFactor.DIM_RATIO: self.dim / 4096,
|
||||||
|
InitStdFactor.DISABLED: 1.0,
|
||||||
|
}[self.init_std_factor]
|
||||||
|
|
||||||
|
layer.init_weights(init_std, factor)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalEncoder(LocalModelBase):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__(args)
|
||||||
|
self.output_proj = (
|
||||||
|
args.patching_mode in ["entropy", "probmax"]
|
||||||
|
) and args.entropy_model_checkpoint_dir is None
|
||||||
|
|
||||||
|
self.apply_transformer = args.use_local_encoder_transformer
|
||||||
|
self.downsampling_by_pooling = args.downsampling_by_pooling
|
||||||
|
self.patch_only = args.patch_only_encoder
|
||||||
|
self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
|
||||||
|
self.cross_attn_encoder = args.cross_attn_encoder
|
||||||
|
self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
|
||||||
|
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
|
||||||
|
self.cross_attn_nheads = args.cross_attn_nheads
|
||||||
|
|
||||||
|
if self.cross_attn_encoder:
|
||||||
|
self.cross_attn_layers = torch.nn.ModuleList()
|
||||||
|
layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
|
||||||
|
for _ in range(layers_to_add):
|
||||||
|
self.cross_attn_layers.append(
|
||||||
|
CrossAttention(
|
||||||
|
dim=self.dim,
|
||||||
|
head_dim=self.dim // self.cross_attn_nheads,
|
||||||
|
n_heads=self.cross_attn_nheads,
|
||||||
|
n_kv_heads=self.cross_attn_nheads,
|
||||||
|
norm_eps=args.norm_eps,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_embedding(self, tokens, embeds):
|
||||||
|
if embeds is not None:
|
||||||
|
assert (
|
||||||
|
self.expects_hash_embeddings
|
||||||
|
), "Not expecting embeddings to be passed."
|
||||||
|
return embeds
|
||||||
|
else:
|
||||||
|
return self.tok_embeddings(tokens)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tokens: torch.Tensor,
|
||||||
|
embeds: Optional[torch.Tensor] = None,
|
||||||
|
patch_embeds: Optional[torch.Tensor] = None,
|
||||||
|
mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
|
||||||
|
cross_mask: Optional[torch.Tensor] = None,
|
||||||
|
num_patches: Optional[int] = None,
|
||||||
|
patch_ids: Optional[torch.Tensor] = None,
|
||||||
|
cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
|
||||||
|
):
|
||||||
|
""" """
|
||||||
|
bs, seqlen = tokens.shape
|
||||||
|
if mask is None:
|
||||||
|
mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
|
||||||
|
|
||||||
|
h = self.apply_embedding(tokens, embeds)
|
||||||
|
freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
|
||||||
|
|
||||||
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
|
||||||
|
# check if cross attention should be applied to either all layer or only the last layer
|
||||||
|
if self.cross_attn_encoder and (
|
||||||
|
i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
|
||||||
|
):
|
||||||
|
patch_embeds = self.apply_cross_attention(
|
||||||
|
h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
h_residual = patch_embeds if self.cross_attn_encoder else None
|
||||||
|
return (h, h_residual), cache
|
||||||
|
|
||||||
|
def apply_cross_attention(
|
||||||
|
self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask
|
||||||
|
):
|
||||||
|
# apply pooling and project
|
||||||
|
if self.cross_attn_init_by_pooling and patch_embeds is None:
|
||||||
|
patch_embeds = downsample(
|
||||||
|
h,
|
||||||
|
num_patches,
|
||||||
|
patch_ids=patch_ids,
|
||||||
|
downsampling_by_pooling=self.downsampling_by_pooling,
|
||||||
|
patch_size=self.patch_size,
|
||||||
|
)
|
||||||
|
if self.patch_embedding_projection is not None:
|
||||||
|
patch_embeds = self.patch_embedding_projection(patch_embeds)
|
||||||
|
patch_embeds = patch_embeds.reshape(
|
||||||
|
bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
|
||||||
|
)
|
||||||
|
|
||||||
|
layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0
|
||||||
|
patch_embeds_cross = self.cross_attn_layers[layer_idx](
|
||||||
|
x=patch_embeds,
|
||||||
|
kv=h,
|
||||||
|
mask=cross_mask,
|
||||||
|
)
|
||||||
|
patch_embeds += patch_embeds_cross
|
||||||
|
return patch_embeds
|
||||||
|
|
||||||
|
|
||||||
|
class LocalDecoder(LocalModelBase):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__(args)
|
||||||
|
|
||||||
|
# Model configuration flags
|
||||||
|
self.patch_only = args.patch_only_decoder
|
||||||
|
self.expects_embeddings = args.share_encoder_decoder_emb
|
||||||
|
self.cross_attn_decoder = args.cross_attn_decoder
|
||||||
|
self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
|
||||||
|
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
|
||||||
|
self.cross_attn_nheads = args.cross_attn_nheads
|
||||||
|
|
||||||
|
if self.cross_attn_decoder:
|
||||||
|
self.cross_attn_layers = torch.nn.ModuleList()
|
||||||
|
layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1
|
||||||
|
for _ in range(layers_to_add):
|
||||||
|
self.cross_attn_layers.append(
|
||||||
|
CrossAttention(
|
||||||
|
dim=self.dim,
|
||||||
|
head_dim=self.dim // self.cross_attn_nheads,
|
||||||
|
n_heads=self.cross_attn_nheads,
|
||||||
|
n_kv_heads=self.cross_attn_nheads,
|
||||||
|
norm_eps=args.norm_eps,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.output = nn.Linear(
|
||||||
|
self.dim,
|
||||||
|
args.vocab_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tokens: torch.Tensor,
|
||||||
|
embeds: Optional[torch.Tensor],
|
||||||
|
patch_embeds: Optional[torch.Tensor] = None,
|
||||||
|
mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
|
||||||
|
cross_mask: Optional[torch.Tensor] = None,
|
||||||
|
cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
|
||||||
|
):
|
||||||
|
bs, seqlen = tokens.shape
|
||||||
|
assert embeds is not None, "Embeddings must be provided"
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
|
||||||
|
|
||||||
|
h = embeds
|
||||||
|
|
||||||
|
if self.patch_embedding_projection is not None:
|
||||||
|
assert patch_embeds is not None, "Patch embeddings must be passed."
|
||||||
|
patch_embeds = self.patch_embedding_projection(patch_embeds)
|
||||||
|
if self.cross_attn_k is not None:
|
||||||
|
patch_embeds = patch_embeds.reshape(
|
||||||
|
bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
|
||||||
|
)
|
||||||
|
|
||||||
|
if patch_embeds is not None and not self.cross_attn_decoder:
|
||||||
|
h = h + patch_embeds
|
||||||
|
|
||||||
|
freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
|
||||||
|
|
||||||
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
if self.cross_attn_decoder and (
|
||||||
|
i == 0 or self.cross_attn_all_layers_decoder
|
||||||
|
):
|
||||||
|
# Use cross attention to extract info from patch_embeds into h
|
||||||
|
h_cross = self.cross_attn_layers[i](
|
||||||
|
x=h,
|
||||||
|
kv=patch_embeds,
|
||||||
|
mask=cross_mask,
|
||||||
|
)
|
||||||
|
h = h + h_cross
|
||||||
|
|
||||||
|
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
|
||||||
|
|
||||||
|
h_preds = self.norm(h)
|
||||||
|
h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
|
||||||
|
h_preds = self.output(h_preds)
|
||||||
|
h_preds = h_preds.float()
|
||||||
|
return h_preds, cache
|
199
bytelatent/model/transformer.py
Normal file
199
bytelatent/model/transformer.py
Normal file
|
@ -0,0 +1,199 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.nn.attention.flex_attention import BlockMask
|
||||||
|
from xformers.ops import AttentionBias
|
||||||
|
|
||||||
|
from bytelatent.base_transformer import (
|
||||||
|
BaseTransformer,
|
||||||
|
RMSNorm,
|
||||||
|
flex_attention_comp,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
from bytelatent.model.utils import create_causal_mask
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
CrossAttention block to attend to the encoder states from the decoder.
|
||||||
|
Rope is not supported.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
head_dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
norm_eps: float,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_kv_heads = n_kv_heads
|
||||||
|
self.heads_per_group = self.n_heads // self.n_kv_heads
|
||||||
|
|
||||||
|
self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
|
||||||
|
self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
|
||||||
|
|
||||||
|
self.wq = nn.Linear(
|
||||||
|
dim,
|
||||||
|
n_heads * head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.wk = nn.Linear(
|
||||||
|
dim,
|
||||||
|
n_kv_heads * head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.wv = nn.Linear(
|
||||||
|
dim,
|
||||||
|
n_kv_heads * head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.wo = nn.Linear(
|
||||||
|
n_heads * head_dim,
|
||||||
|
dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
kv: torch.Tensor,
|
||||||
|
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# B S D
|
||||||
|
bsz, seq_len, _ = x.shape
|
||||||
|
_, slen_kv, _ = kv.shape
|
||||||
|
x = self.cross_attn_norm_q(x)
|
||||||
|
kv = self.cross_attn_norm_kv(kv)
|
||||||
|
|
||||||
|
xq = self.wq(x)
|
||||||
|
xk = self.wk(kv)
|
||||||
|
xv = self.wv(kv)
|
||||||
|
|
||||||
|
output_shape = xq.shape
|
||||||
|
# B S D -> B S H D
|
||||||
|
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
|
||||||
|
xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
|
||||||
|
xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
|
||||||
|
|
||||||
|
xk = repeat_kv(xk, self.heads_per_group, dim=2)
|
||||||
|
xv = repeat_kv(xv, self.heads_per_group, dim=2)
|
||||||
|
|
||||||
|
assert mask is None or isinstance(mask, BlockMask)
|
||||||
|
xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
|
||||||
|
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
|
||||||
|
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
|
||||||
|
|
||||||
|
output = self.wo(output.reshape(output_shape))
|
||||||
|
|
||||||
|
return x + output
|
||||||
|
|
||||||
|
def init_weights(self, base_std: float, factor: float = 1.0):
|
||||||
|
std = base_std * factor
|
||||||
|
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.wq.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=std,
|
||||||
|
a=-3 * std,
|
||||||
|
b=3 * std,
|
||||||
|
)
|
||||||
|
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.wk.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=std,
|
||||||
|
a=-3 * std,
|
||||||
|
b=3 * std,
|
||||||
|
)
|
||||||
|
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.wv.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=std,
|
||||||
|
a=-3 * std,
|
||||||
|
b=3 * std,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_std = std / (2**0.5)
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.wo.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=output_std,
|
||||||
|
a=-3 * output_std,
|
||||||
|
b=3 * output_std,
|
||||||
|
)
|
||||||
|
self.cross_attn_norm_q.reset_parameters()
|
||||||
|
self.cross_attn_norm_kv.reset_parameters()
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalTransformer(BaseTransformer):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__(args)
|
||||||
|
self.dropout = args.dropout
|
||||||
|
self.sliding_window = args.sliding_window
|
||||||
|
self.efficient_attn = args.efficient_attn
|
||||||
|
|
||||||
|
self.token_embedding_projection = None
|
||||||
|
if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
|
||||||
|
self.token_embedding_projection = nn.Linear(
|
||||||
|
args.dim_token_emb,
|
||||||
|
args.dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tokens: torch.Tensor,
|
||||||
|
tok_idx: Optional[torch.Tensor] = None,
|
||||||
|
embeds: Optional[torch.Tensor] = None,
|
||||||
|
mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
|
||||||
|
cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Similar to BaseTransformer.forward, but with an additional embeds argument
|
||||||
|
and projection to the token space.
|
||||||
|
"""
|
||||||
|
bs, seqlen = tokens.shape
|
||||||
|
attn_impl = self.efficient_attn
|
||||||
|
|
||||||
|
h = embeds
|
||||||
|
|
||||||
|
mask = (
|
||||||
|
mask
|
||||||
|
if mask is not None
|
||||||
|
else create_causal_mask(seqlen, attn_impl, self.sliding_window)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
|
||||||
|
h = self.token_embedding_projection(h)
|
||||||
|
|
||||||
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
|
||||||
|
return h, cache
|
||||||
|
|
||||||
|
def init_weights(self, init_base_std: float):
|
||||||
|
super().init_weights()
|
||||||
|
if self.token_embedding_projection is not None:
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.token_embedding_projection.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=init_base_std,
|
||||||
|
a=-3 * init_base_std,
|
||||||
|
b=3 * init_base_std,
|
||||||
|
)
|
116
bytelatent/model/utils.py
Normal file
116
bytelatent/model/utils.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import torch
|
||||||
|
from torch.nn.attention.flex_attention import create_block_mask
|
||||||
|
from xformers.ops import fmha
|
||||||
|
|
||||||
|
|
||||||
|
def patch_reduce(h, max_num_patches, reduction, patch_ids):
|
||||||
|
"""
|
||||||
|
Reduce variable length patches to single embedding per patch
|
||||||
|
Note: this works with variable number of patches for different sequences in the batch
|
||||||
|
It handles variable length patches by assuming that patch_lengths will be 0 for any
|
||||||
|
extra patches on the *right*. Since there can be a variable number of patches
|
||||||
|
this function also return the number of patches for each sequence in the batch.
|
||||||
|
Any embeddings on the right that are not allocated to a patch
|
||||||
|
(i.e. if the sum(patch_lengths[i]) < seq_len for any i)
|
||||||
|
will be sent to a dummy patch, which is trimmed before returning.
|
||||||
|
"""
|
||||||
|
bs, seq_len, emb_dim = h.shape
|
||||||
|
|
||||||
|
patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
|
||||||
|
|
||||||
|
reduced_embs = torch.zeros(
|
||||||
|
(bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device
|
||||||
|
)
|
||||||
|
reduced_embs = reduced_embs.scatter_reduce(
|
||||||
|
src=h,
|
||||||
|
dim=1,
|
||||||
|
index=patch_ids,
|
||||||
|
reduce=reduction,
|
||||||
|
include_self=False,
|
||||||
|
)
|
||||||
|
reduced_embs = reduced_embs[:, :max_num_patches, :]
|
||||||
|
|
||||||
|
return reduced_embs
|
||||||
|
|
||||||
|
|
||||||
|
def concat_downsample(h, patch_lengths, patch_size):
|
||||||
|
# The assumption in this function is that seq_len = patch_size * num_patches.
|
||||||
|
bs, seq_len, emb_dim = h.shape
|
||||||
|
patch_end_ids = torch.cumsum(patch_lengths, dim=1)
|
||||||
|
patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to(
|
||||||
|
patch_end_ids.device
|
||||||
|
)
|
||||||
|
# Is clamp ok here?
|
||||||
|
patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1])
|
||||||
|
patch_ids = patch_ids.view(bs, -1, emb_dim)
|
||||||
|
# after gather h.shape = [batch_size, seq_len, dim]
|
||||||
|
h = torch.gather(h, 1, patch_ids)
|
||||||
|
h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1))
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids):
|
||||||
|
cat = []
|
||||||
|
if "avg" in pooling_mode or "mean" in pooling_mode:
|
||||||
|
cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids))
|
||||||
|
if "min" in pooling_mode:
|
||||||
|
cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids))
|
||||||
|
if "max" in pooling_mode:
|
||||||
|
cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids))
|
||||||
|
assert len(cat) > 0
|
||||||
|
h = torch.cat(cat, dim=-1)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def downsample(
|
||||||
|
h,
|
||||||
|
num_patches,
|
||||||
|
patch_lengths=None,
|
||||||
|
patch_ids=None,
|
||||||
|
downsampling_by_pooling=None,
|
||||||
|
patch_size=4,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Downsampling:
|
||||||
|
a. concatenating embeddings in the patch
|
||||||
|
Note: with dynamic patching, patch the last patch_size tokens.
|
||||||
|
b. pooling embeddings in the patch
|
||||||
|
"""
|
||||||
|
# input: h.shape = [batch_size, seq_len, dim]
|
||||||
|
# input: pool h.shape = [batch_size, seq_len / patch_size, dim]
|
||||||
|
# if we don't use the cros_attn, we pool so that we convert bytes rep to patch rep
|
||||||
|
if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0:
|
||||||
|
# By pooling
|
||||||
|
max_num_patches = num_patches
|
||||||
|
assert patch_ids is not None
|
||||||
|
h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids)
|
||||||
|
else:
|
||||||
|
# TODO: remove this condition
|
||||||
|
# By concatenating (fixed lengths patching)
|
||||||
|
assert patch_lengths is not None
|
||||||
|
h = concat_downsample(h, patch_lengths, patch_size)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def causal_mask(b, h, q_idx, kv_idx):
|
||||||
|
return q_idx >= kv_idx
|
||||||
|
|
||||||
|
|
||||||
|
def create_causal_mask(seqlen, attn_impl, sliding_window):
|
||||||
|
if sliding_window is not None and attn_impl == "xformers":
|
||||||
|
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
|
||||||
|
window_left=sliding_window - 1, window_right=0
|
||||||
|
)
|
||||||
|
elif attn_impl == "xformers":
|
||||||
|
return fmha.attn_bias.LowerTriangularMask()
|
||||||
|
elif attn_impl == "sdpa":
|
||||||
|
return "causal"
|
||||||
|
elif attn_impl == "flex_attention":
|
||||||
|
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
|
||||||
|
elif attn_impl == "fmha":
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
|
||||||
|
)
|
162
bytelatent/optim.py
Normal file
162
bytelatent/optim.py
Normal file
|
@ -0,0 +1,162 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from torch import nn
|
||||||
|
from torch.optim import AdamW, lr_scheduler
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class OptimArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
lr: float = 3e-4
|
||||||
|
weight_decay: float = 0.1
|
||||||
|
epsilon: float = 1e-8
|
||||||
|
beta1: float = 0.9
|
||||||
|
beta2: float = 0.95
|
||||||
|
clip: float = 1.0
|
||||||
|
|
||||||
|
scheduler: str = "cosine"
|
||||||
|
warmup: int = 2000
|
||||||
|
lr_min_ratio: float = 0.1
|
||||||
|
cycle_length: float = 1.0
|
||||||
|
cosine_theta: float = 1.0
|
||||||
|
annealing_step: int = 1000
|
||||||
|
decay_fraction: float = 0.1
|
||||||
|
|
||||||
|
exp_factor: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def lr_linear(step: int, warmup: int, n_steps: int, min_ratio: float) -> float:
|
||||||
|
if step < warmup:
|
||||||
|
lr = float(step) / warmup
|
||||||
|
elif step <= n_steps:
|
||||||
|
s = float(step - warmup) / (n_steps - warmup)
|
||||||
|
lr = s * min_ratio + (1 - s)
|
||||||
|
else:
|
||||||
|
lr = min_ratio
|
||||||
|
return lr
|
||||||
|
|
||||||
|
|
||||||
|
def lr_inv_sqrt(step: int, warmup: int, exp_factor: float, min_ratio: float) -> float:
|
||||||
|
if step < warmup:
|
||||||
|
lr = float(step) / warmup
|
||||||
|
else:
|
||||||
|
lr = max((warmup**exp_factor) / (step**exp_factor), min_ratio)
|
||||||
|
return lr
|
||||||
|
|
||||||
|
|
||||||
|
def lr_cosine(
|
||||||
|
step: int,
|
||||||
|
warmup: int,
|
||||||
|
n_steps: int,
|
||||||
|
cycle_length: float,
|
||||||
|
theta: float,
|
||||||
|
min_ratio: float,
|
||||||
|
) -> float:
|
||||||
|
sign = ((step // (n_steps * cycle_length)) % 2) * -2 + 1
|
||||||
|
if step < warmup:
|
||||||
|
lr = float(step) / warmup
|
||||||
|
elif step <= n_steps:
|
||||||
|
s = float(step - warmup) / (n_steps - warmup)
|
||||||
|
lr = min_ratio + 0.5 * (1 - min_ratio) * (
|
||||||
|
sign * math.cos(math.pi * s**theta / cycle_length) + 1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lr = min_ratio
|
||||||
|
return lr
|
||||||
|
|
||||||
|
|
||||||
|
def lr_wsd(
|
||||||
|
step: int,
|
||||||
|
warmup: int,
|
||||||
|
n_steps: int,
|
||||||
|
decay_fraction: float,
|
||||||
|
cycle_length: float,
|
||||||
|
min_ratio: float,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
UNDERSTANDING WARMUP-STABLE-DECAY LEARNING RATES: A RIVER VALLEY LOSS LANDSCAPE PERSPECTIVE
|
||||||
|
https://arxiv.org/pdf/2410.05192
|
||||||
|
"""
|
||||||
|
cycle_num = step // int(n_steps * cycle_length) + 1
|
||||||
|
curr_n_steps = int(n_steps * cycle_length) * cycle_num
|
||||||
|
decay_length = int(curr_n_steps * decay_fraction)
|
||||||
|
|
||||||
|
if step < warmup:
|
||||||
|
lr = float(step) / warmup
|
||||||
|
elif step <= curr_n_steps - decay_length:
|
||||||
|
lr = 1.0
|
||||||
|
elif step > curr_n_steps - decay_length and step <= curr_n_steps:
|
||||||
|
# Linear interpolation gives similar results
|
||||||
|
# slope = -(1.0 - min_ratio) / decay_length
|
||||||
|
# intercept = min_ratio + ((1.0 - min_ratio) * curr_n_steps) / decay_length
|
||||||
|
# lr = slope * step + intercept
|
||||||
|
|
||||||
|
step = step - (curr_n_steps - decay_length)
|
||||||
|
lr = 1 / ((step / curr_n_steps) * (1 / min_ratio) + (1 - step / curr_n_steps))
|
||||||
|
else:
|
||||||
|
lr = min_ratio
|
||||||
|
|
||||||
|
return lr
|
||||||
|
|
||||||
|
|
||||||
|
def build_lr_fn(args: OptimArgs, n_steps: int):
|
||||||
|
if args.scheduler == "constant":
|
||||||
|
lr_fn = lambda x: 1.0
|
||||||
|
elif args.scheduler == "linear":
|
||||||
|
lr_fn = partial(
|
||||||
|
lr_linear, warmup=args.warmup, n_steps=n_steps, min_ratio=args.lr_min_ratio
|
||||||
|
)
|
||||||
|
elif args.scheduler == "inv_sqrt":
|
||||||
|
lr_fn = partial(
|
||||||
|
lr_inv_sqrt,
|
||||||
|
warmup=args.warmup,
|
||||||
|
exp_factor=args.exp_factor,
|
||||||
|
min_ratio=args.lr_min_ratio,
|
||||||
|
)
|
||||||
|
elif args.scheduler == "cosine":
|
||||||
|
lr_fn = partial(
|
||||||
|
lr_cosine,
|
||||||
|
warmup=args.warmup,
|
||||||
|
n_steps=n_steps,
|
||||||
|
cycle_length=args.cycle_length,
|
||||||
|
theta=args.cosine_theta,
|
||||||
|
min_ratio=args.lr_min_ratio,
|
||||||
|
)
|
||||||
|
elif args.scheduler == "wsd":
|
||||||
|
assert args.decay_fraction < args.cycle_length
|
||||||
|
lr_fn = partial(
|
||||||
|
lr_wsd,
|
||||||
|
warmup=args.warmup,
|
||||||
|
n_steps=n_steps,
|
||||||
|
decay_fraction=args.decay_fraction,
|
||||||
|
cycle_length=args.cycle_length,
|
||||||
|
min_ratio=args.lr_min_ratio,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown scheduler: {args.scheduler}")
|
||||||
|
return lr_fn
|
||||||
|
|
||||||
|
|
||||||
|
def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int):
|
||||||
|
logger.info("Starting build of optimizer...")
|
||||||
|
optimizer = AdamW(
|
||||||
|
model.parameters(),
|
||||||
|
lr=args.lr,
|
||||||
|
betas=(args.beta1, args.beta2),
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
eps=args.epsilon,
|
||||||
|
fused=True, # Faster optim.step but can throw errors
|
||||||
|
)
|
||||||
|
|
||||||
|
# scheduler
|
||||||
|
lr_fn = build_lr_fn(args, n_steps)
|
||||||
|
scheduler = lr_scheduler.LambdaLR(optimizer, lr_fn)
|
||||||
|
|
||||||
|
logger.info("Done with build of optimizer.")
|
||||||
|
return optimizer, scheduler
|
1
bytelatent/plotting/__init__.py
Normal file
1
bytelatent/plotting/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
bytelatent/plotting/config_entropy_figure.yaml
Normal file
3
bytelatent/plotting/config_entropy_figure.yaml
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
data_path: plot_data/entropy_figure.json
|
||||||
|
chart_path: figures/entropy_figure.pdf
|
||||||
|
# chart_path: figures/entropy_figure.pdf
|
4
bytelatent/plotting/config_scaling_figures.yaml
Normal file
4
bytelatent/plotting/config_scaling_figures.yaml
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
df_dir: /home/par/blt_df
|
||||||
|
output_chart_dir: figures/
|
||||||
|
frame_files:
|
||||||
|
["4b_df.json", "500m_df.json", "scaling_arch_df.json", "scaling_df.json"]
|
85
bytelatent/plotting/entropy_figure.py
Normal file
85
bytelatent/plotting/entropy_figure.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import altair as alt
|
||||||
|
import pandas as pd
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class PlotEntropiesConfig(BaseModel):
|
||||||
|
data_path: str | None
|
||||||
|
chart_path: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
|
|
||||||
|
class PlotEntropiesData(BaseModel):
|
||||||
|
text: str
|
||||||
|
threshold: float = 1.335442066192627
|
||||||
|
dataframe_json: str | None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config_path = sys.argv[1]
|
||||||
|
file_config = OmegaConf.load(config_path)
|
||||||
|
# Omit program name and config file name
|
||||||
|
cli_conf = OmegaConf.from_cli(sys.argv[2:])
|
||||||
|
conf_dict = OmegaConf.to_container(
|
||||||
|
OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
|
||||||
|
)
|
||||||
|
plot_config = PlotEntropiesConfig(**conf_dict)
|
||||||
|
with open(plot_config.data_path) as f:
|
||||||
|
json_data = f.read()
|
||||||
|
plot_data = PlotEntropiesData.model_validate_json(json_data)
|
||||||
|
df = pd.read_json(plot_data.dataframe_json)
|
||||||
|
|
||||||
|
x_ticks = []
|
||||||
|
for row in df.itertuples():
|
||||||
|
position = row.position
|
||||||
|
token = row.tokens
|
||||||
|
x_ticks.append(f"{str(position).zfill(3)}|{token}")
|
||||||
|
df["position_with_token"] = x_ticks
|
||||||
|
print(df)
|
||||||
|
|
||||||
|
x_axis = alt.Axis(
|
||||||
|
labelExpr="split(datum.label, '|')[1]",
|
||||||
|
grid=False,
|
||||||
|
labelOverlap=False,
|
||||||
|
labelAngle=0,
|
||||||
|
)
|
||||||
|
width = 1200
|
||||||
|
height = 150
|
||||||
|
base = alt.Chart(df).properties(width=width, height=height)
|
||||||
|
points = base.mark_line(point=True).encode(
|
||||||
|
x=alt.X("position_with_token:O", title=None, axis=x_axis),
|
||||||
|
y=alt.Y(
|
||||||
|
"entropies",
|
||||||
|
title="Entropy of Next Byte",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
|
||||||
|
y=alt.datum(plot_data.threshold),
|
||||||
|
)
|
||||||
|
patch_rules = (
|
||||||
|
alt.Chart(df[df["start"] > 0])
|
||||||
|
.properties(width=width, height=height)
|
||||||
|
.mark_rule(color="#474747", strokeDash=[4, 2])
|
||||||
|
.encode(x=alt.X("position_with_token:O", axis=x_axis))
|
||||||
|
)
|
||||||
|
|
||||||
|
chart = patch_rules + rule + points
|
||||||
|
chart = chart.configure_axis(labelFontSize=15, titleFontSize=15)
|
||||||
|
path = Path(plot_config.chart_path)
|
||||||
|
path.parent.mkdir(exist_ok=True)
|
||||||
|
chart.save(path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
108
bytelatent/plotting/scaling_figures.py
Normal file
108
bytelatent/plotting/scaling_figures.py
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import altair as alt
|
||||||
|
import pandas as pd
|
||||||
|
import pydantic
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
|
||||||
|
class ScalingPlotsConfig(pydantic.BaseModel):
|
||||||
|
df_dir: str
|
||||||
|
output_chart_dir: str
|
||||||
|
frame_files: list[str]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
|
|
||||||
|
def determine_family(key: str):
|
||||||
|
if key.startswith("Megabyte++"):
|
||||||
|
return "Megabyte++"
|
||||||
|
elif key.startswith("BLT"):
|
||||||
|
return "BLT"
|
||||||
|
elif key.startswith("LLaMA"):
|
||||||
|
return "LLaMA"
|
||||||
|
elif key.startswith("Space"):
|
||||||
|
return "Space"
|
||||||
|
|
||||||
|
|
||||||
|
file_to_vars = {}
|
||||||
|
|
||||||
|
|
||||||
|
def create_chart(df: pd.DataFrame, output_file: str):
|
||||||
|
df["metric"] = df["bpb/not_heldout.jsonl"]
|
||||||
|
df["family"] = df["key"].map(determine_family)
|
||||||
|
model_domain = [
|
||||||
|
"BLT Space ps=6",
|
||||||
|
"BLT Space w/o cross-attn",
|
||||||
|
"SpaceByte",
|
||||||
|
"LLaMA 3 BPE",
|
||||||
|
"Megabyte++ ps=4",
|
||||||
|
"Megabyte++ ps=6",
|
||||||
|
]
|
||||||
|
color_range = ["#1f77b4", "#1f77b4", "#1f77b4", "#ff7f0e", "#2ca02c", "#2ca02c"]
|
||||||
|
shape_range = [
|
||||||
|
"circle",
|
||||||
|
"square",
|
||||||
|
"cross",
|
||||||
|
"diamond",
|
||||||
|
"triangle-up",
|
||||||
|
"triangle-down",
|
||||||
|
]
|
||||||
|
color_scale = alt.Scale(domain=model_domain, range=color_range)
|
||||||
|
shape_scale = alt.Scale(
|
||||||
|
domain=model_domain,
|
||||||
|
range=shape_range,
|
||||||
|
)
|
||||||
|
base_chart = alt.Chart(df).encode(
|
||||||
|
x=alt.X("flops", title="Training FLOPS")
|
||||||
|
.scale(type="log", domain=[2e20, 1.25e22])
|
||||||
|
.axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]),
|
||||||
|
y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False),
|
||||||
|
)
|
||||||
|
lines = base_chart.encode(
|
||||||
|
color=alt.Color("key", title="Model Color", scale=color_scale, legend=None),
|
||||||
|
strokeDash=alt.StrokeDash("family", title="Model Family", legend=None),
|
||||||
|
).mark_line()
|
||||||
|
points = base_chart.encode(
|
||||||
|
color=alt.Color("key", title="Model", scale=color_scale),
|
||||||
|
shape=alt.Shape("key", title="", scale=shape_scale),
|
||||||
|
).mark_point(size=70)
|
||||||
|
chart = (
|
||||||
|
(lines + points)
|
||||||
|
.resolve_scale(
|
||||||
|
color="independent",
|
||||||
|
shape="independent",
|
||||||
|
# strokeDash="independent",
|
||||||
|
)
|
||||||
|
.configure_legend(orient="right")
|
||||||
|
.properties(height=300, width=400)
|
||||||
|
)
|
||||||
|
print("Saving", output_file)
|
||||||
|
chart.save(output_file)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config_path = sys.argv[1]
|
||||||
|
file_config = OmegaConf.load(config_path)
|
||||||
|
# Omit program name and config file name
|
||||||
|
cli_conf = OmegaConf.from_cli(sys.argv[2:])
|
||||||
|
conf_dict = OmegaConf.to_container(
|
||||||
|
OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
|
||||||
|
)
|
||||||
|
plot_config = ScalingPlotsConfig(**conf_dict)
|
||||||
|
df_dir = Path(plot_config.df_dir)
|
||||||
|
chart_dir = Path(plot_config.output_chart_dir)
|
||||||
|
chart_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
for ff in plot_config.frame_files:
|
||||||
|
path = df_dir / ff
|
||||||
|
df = pd.read_json(path)
|
||||||
|
print(df)
|
||||||
|
print(df.columns)
|
||||||
|
create_chart(df, chart_dir / f"{path.name}.pdf")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
bytelatent/preprocess/__init__.py
Normal file
1
bytelatent/preprocess/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
74
bytelatent/preprocess/data_pipeline.py
Normal file
74
bytelatent/preprocess/data_pipeline.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import luigi
|
||||||
|
|
||||||
|
# CHANGEME: Change this to point to your data
|
||||||
|
BASE_DIR = Path("datasets")
|
||||||
|
DATASETS = ["dclm"]
|
||||||
|
TARGET_DIR = Path("entropy_preprocess")
|
||||||
|
|
||||||
|
SHARD_SCRIPT = """split -C 2500m -d {source} {destination}.shard_"""
|
||||||
|
|
||||||
|
|
||||||
|
def list_dataset_shards(dataset: str):
|
||||||
|
dataset_dir = BASE_DIR / dataset
|
||||||
|
return list(dataset_dir.glob("*.chunk.*.jsonl"))
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkFile(luigi.ExternalTask):
|
||||||
|
file = luigi.Parameter()
|
||||||
|
|
||||||
|
def output(self):
|
||||||
|
return luigi.LocalTarget(self.file)
|
||||||
|
|
||||||
|
|
||||||
|
class ShardDatasetChunk(luigi.Task):
|
||||||
|
dataset_name = luigi.Parameter()
|
||||||
|
chunk_file = luigi.Parameter()
|
||||||
|
|
||||||
|
def _chunk_filename(self):
|
||||||
|
return Path(self.chunk_file).name
|
||||||
|
|
||||||
|
def requires(self):
|
||||||
|
return ChunkFile(self.chunk_file)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
destination_dir = TARGET_DIR / str(self.dataset_name)
|
||||||
|
destination_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
destination = destination_dir / self._chunk_filename()
|
||||||
|
subprocess.check_output(
|
||||||
|
SHARD_SCRIPT.format(source=str(self.chunk_file), destination=destination),
|
||||||
|
shell=True,
|
||||||
|
)
|
||||||
|
(
|
||||||
|
Path(TARGET_DIR)
|
||||||
|
/ str(self.dataset_name)
|
||||||
|
/ f"{self._chunk_filename()}.shard.COMPLETE"
|
||||||
|
).touch()
|
||||||
|
|
||||||
|
def output(self):
|
||||||
|
return luigi.LocalTarget(
|
||||||
|
TARGET_DIR
|
||||||
|
/ str(self.dataset_name)
|
||||||
|
/ f"{self._chunk_filename()}.shard.COMPLETE"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ShardDataset(luigi.WrapperTask):
|
||||||
|
dataset_name = luigi.Parameter()
|
||||||
|
|
||||||
|
def requires(self):
|
||||||
|
for f in list_dataset_shards(self.dataset_name):
|
||||||
|
yield ShardDatasetChunk(dataset_name=self.dataset_name, chunk_file=str(f))
|
||||||
|
|
||||||
|
|
||||||
|
class ShardAllDatasets(luigi.WrapperTask):
|
||||||
|
def requires(self):
|
||||||
|
for d in DATASETS:
|
||||||
|
yield ShardDataset(dataset_name=d)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
luigi.build([ShardAllDatasets()], local_scheduler=True, workers=128)
|
108
bytelatent/preprocess/parallel_entropies.py
Normal file
108
bytelatent/preprocess/parallel_entropies.py
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import submitit
|
||||||
|
import typer
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessEntropiesJob(submitit.helpers.Checkpointable):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, shard_file: str, output_filename: str):
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
"python",
|
||||||
|
"-u",
|
||||||
|
"-m",
|
||||||
|
"bytelatent.preprocess.preprocess_entropies",
|
||||||
|
str(shard_file),
|
||||||
|
str(output_filename),
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def chunk(items, size):
|
||||||
|
for i in range(0, len(items), size):
|
||||||
|
yield items[i : i + size]
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
job_folder: str,
|
||||||
|
input_dir: str,
|
||||||
|
output_dir: str,
|
||||||
|
qos: str = "explore",
|
||||||
|
slurm_batch_size: int = 1000,
|
||||||
|
check_only: bool = False,
|
||||||
|
wait: bool = False,
|
||||||
|
):
|
||||||
|
input_dir = Path(input_dir)
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
shard_files = [
|
||||||
|
p for p in input_dir.glob("*.jsonl.shard*") if "COMPLETE" not in p.name
|
||||||
|
]
|
||||||
|
if check_only:
|
||||||
|
exist = []
|
||||||
|
missing = []
|
||||||
|
for shard_file in shard_files:
|
||||||
|
shard_file = Path(shard_file)
|
||||||
|
complete_file = output_dir / f"{shard_file.name}.arrow.complete"
|
||||||
|
if complete_file.exists():
|
||||||
|
exist.append(complete_file)
|
||||||
|
else:
|
||||||
|
missing.append(complete_file)
|
||||||
|
print("Checked for output files for input_dir=", input_dir)
|
||||||
|
print("Exist:", len(exist))
|
||||||
|
print("Missing:", len(missing))
|
||||||
|
print(missing)
|
||||||
|
return
|
||||||
|
print("Running parallel job over N files=", len(shard_files))
|
||||||
|
print("Input Directory:", input_dir)
|
||||||
|
print("Output Directory:", output_dir)
|
||||||
|
output_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
executor = submitit.SlurmExecutor(job_folder)
|
||||||
|
executor.update_parameters(
|
||||||
|
# 12 hours in minutes
|
||||||
|
time=60 * 12,
|
||||||
|
qos=qos,
|
||||||
|
exclusive="user",
|
||||||
|
cpus_per_task=4,
|
||||||
|
num_gpus=1,
|
||||||
|
mem_per_gpu="80G",
|
||||||
|
array_parallelism=slurm_batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
jobs = []
|
||||||
|
n_batches = 0
|
||||||
|
n_skipped = 0
|
||||||
|
n_launched = 0
|
||||||
|
for file_batch in chunk(shard_files, slurm_batch_size):
|
||||||
|
with executor.batch():
|
||||||
|
for shard_file in file_batch:
|
||||||
|
output_filename = Path(output_dir) / f"{shard_file.name}.arrow"
|
||||||
|
complete_output_filename = (
|
||||||
|
Path(output_dir) / f"{shard_file.name}.arrow.complete"
|
||||||
|
)
|
||||||
|
if complete_output_filename.exists():
|
||||||
|
n_skipped += 1
|
||||||
|
else:
|
||||||
|
job = executor.submit(
|
||||||
|
PreprocessEntropiesJob(), str(shard_file), str(output_filename)
|
||||||
|
)
|
||||||
|
n_launched += 1
|
||||||
|
jobs.append(job)
|
||||||
|
n_batches += 1
|
||||||
|
print("launched array jobs n=", n_launched)
|
||||||
|
print("skipped (completed) array jobs n=", n_skipped)
|
||||||
|
print("number of slurm batches=", n_batches)
|
||||||
|
if wait:
|
||||||
|
output = [job.result() for job in jobs]
|
||||||
|
assert all(output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(main)
|
141
bytelatent/preprocess/preprocess_entropies.py
Normal file
141
bytelatent/preprocess/preprocess_entropies.py
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
import torch
|
||||||
|
import typer
|
||||||
|
from rich.progress import Progress, TextColumn
|
||||||
|
|
||||||
|
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
input_file: str,
|
||||||
|
output_file: str,
|
||||||
|
patching_device: str = "cuda",
|
||||||
|
log_step: int = 10_000,
|
||||||
|
entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
|
||||||
|
dry_run: bool = False,
|
||||||
|
):
|
||||||
|
# TODO: Modify this to work with the new code
|
||||||
|
raise NotImplementedError()
|
||||||
|
iterator = ArrowFileIterator(
|
||||||
|
file_path=input_file,
|
||||||
|
worker_id=0,
|
||||||
|
num_workers=1,
|
||||||
|
)
|
||||||
|
tokenization_mode = "bytes"
|
||||||
|
print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
|
||||||
|
print("Loading entropy model", entropy_model_checkpoint_dir)
|
||||||
|
if dry_run:
|
||||||
|
return
|
||||||
|
entropy_model = load_entropy_model(
|
||||||
|
entropy_model_checkpoint_dir, device=patching_device
|
||||||
|
)
|
||||||
|
entropy_model, _ = to_device(entropy_model, patching_device)
|
||||||
|
print("Creating patcher")
|
||||||
|
patching_batch_size = 32
|
||||||
|
print("Creating tokenizer")
|
||||||
|
tokenizer = Tokenizer(
|
||||||
|
model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
|
||||||
|
tokenization_mode=tokenization_mode,
|
||||||
|
# BYTE_UNITS
|
||||||
|
vocab_size_unit_1=256,
|
||||||
|
bos=True,
|
||||||
|
eos=True,
|
||||||
|
bpe_delim=False,
|
||||||
|
# This isn't used, just stores a reference for other calls we don't use
|
||||||
|
patcher=None,
|
||||||
|
)
|
||||||
|
step = 0
|
||||||
|
print("starting")
|
||||||
|
start_time = time.time()
|
||||||
|
patch_time = 0
|
||||||
|
entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
|
||||||
|
sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
|
||||||
|
text_field = pa.field("text", pa.string(), nullable=False)
|
||||||
|
schema = pa.schema([sample_id_field, text_field, entropy_field])
|
||||||
|
arrow_batch_size = 1_000
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pa.OSFile(output_file, "wb") as sink:
|
||||||
|
with pa.ipc.new_file(sink, schema) as writer:
|
||||||
|
id_buffer = []
|
||||||
|
entropies_buffer = []
|
||||||
|
text_buffer = []
|
||||||
|
with Progress(
|
||||||
|
*Progress.get_default_columns(),
|
||||||
|
TextColumn("Completed: {task.completed}"),
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task(
|
||||||
|
"[green]Calculating entropies...", total=None
|
||||||
|
)
|
||||||
|
for doc in iterator:
|
||||||
|
sample_id = get_id_from_doc(doc)
|
||||||
|
|
||||||
|
if "text" in doc:
|
||||||
|
text = doc["text"]
|
||||||
|
elif "content" in doc:
|
||||||
|
text = doc["content"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not find a text key from: {doc.keys()}"
|
||||||
|
)
|
||||||
|
tokens = torch.tensor(tokenizer.encode(text))
|
||||||
|
patch_start = time.time()
|
||||||
|
scores = calculate_entropies(
|
||||||
|
tokens,
|
||||||
|
entropy_model,
|
||||||
|
patching_batch_size,
|
||||||
|
patching_device,
|
||||||
|
)
|
||||||
|
entropies_buffer.append(
|
||||||
|
np.array(scores.tolist(), dtype=np.float16)
|
||||||
|
)
|
||||||
|
id_buffer.append(sample_id)
|
||||||
|
text_buffer.append(text)
|
||||||
|
if len(entropies_buffer) == arrow_batch_size:
|
||||||
|
batch = pa.record_batch(
|
||||||
|
{
|
||||||
|
"entropies": entropies_buffer,
|
||||||
|
"sample_id": id_buffer,
|
||||||
|
"text": text_buffer,
|
||||||
|
},
|
||||||
|
schema,
|
||||||
|
)
|
||||||
|
writer.write(batch)
|
||||||
|
entropies_buffer = []
|
||||||
|
id_buffer = []
|
||||||
|
text_buffer = []
|
||||||
|
patch_time += time.time() - patch_start
|
||||||
|
step += 1
|
||||||
|
if step % log_step == 0:
|
||||||
|
print("Completed steps:", step)
|
||||||
|
progress.update(task, advance=1)
|
||||||
|
if len(entropies_buffer) > 0:
|
||||||
|
# Write last things
|
||||||
|
batch = pa.record_batch(
|
||||||
|
{
|
||||||
|
"entropies": entropies_buffer,
|
||||||
|
"sample_id": id_buffer,
|
||||||
|
"text": text_buffer,
|
||||||
|
},
|
||||||
|
schema,
|
||||||
|
)
|
||||||
|
writer.write(batch)
|
||||||
|
entropies_buffer = []
|
||||||
|
id_buffer = []
|
||||||
|
text_buffer = []
|
||||||
|
Path(f"{output_file}.complete").touch()
|
||||||
|
except:
|
||||||
|
Path(output_file).unlink(missing_ok=True)
|
||||||
|
raise
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print("steps", step)
|
||||||
|
print("done in:", elapsed)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(main)
|
694
bytelatent/probe.py
Normal file
694
bytelatent/probe.py
Normal file
|
@ -0,0 +1,694 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
# This file from the xFormers repo is just a example of how to implement
|
||||||
|
# probing of the activations of a model, without changing anything.
|
||||||
|
# By default, the linear inputs/outputs/gradients are logged, as well as
|
||||||
|
# the attention logits+entropy. It is possible to log an additional tensor, eg:
|
||||||
|
# x = log_stats(x, "name")
|
||||||
|
#
|
||||||
|
# Known limitations:
|
||||||
|
# * Only a subset of the attention biases is supported
|
||||||
|
# * Torch-compile is disabled automatically when this is enabled
|
||||||
|
# * Only tested with bf16/f16/f32 datatypes
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from collections import defaultdict
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||||
|
CheckpointImpl,
|
||||||
|
checkpoint_wrapper,
|
||||||
|
)
|
||||||
|
from torch.fx.operator_schemas import normalize_function
|
||||||
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
|
from torch.utils._python_dispatch import TorchDispatchMode
|
||||||
|
from torch.utils._pytree import tree_map
|
||||||
|
from torch.utils.module_tracker import ModuleTracker
|
||||||
|
from xformers.ops import fmha
|
||||||
|
|
||||||
|
|
||||||
|
@torch.library.custom_op("torchprobe::log", mutates_args=(), device_types=None)
|
||||||
|
def _log(x: torch.Tensor, name: str, uid: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@_log.register_fake
|
||||||
|
def _log_fake(x: torch.Tensor, name: str, uid: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _LogStats(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x: torch.Tensor, name: str):
|
||||||
|
uid = str(uuid.uuid4())
|
||||||
|
torch.ops.torchprobe.log(x, name, uid)
|
||||||
|
ctx.name = name
|
||||||
|
ctx.uid = uid
|
||||||
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad: torch.Tensor):
|
||||||
|
torch.ops.torchprobe.log(grad, f"{ctx.name}.g", ctx.uid)
|
||||||
|
return grad, None
|
||||||
|
|
||||||
|
|
||||||
|
_PROBING_ENABLED = False
|
||||||
|
|
||||||
|
|
||||||
|
def log_stats(x: torch.Tensor, name: str) -> torch.Tensor:
|
||||||
|
if not _PROBING_ENABLED:
|
||||||
|
return x
|
||||||
|
return _LogStats.apply(x, name)
|
||||||
|
|
||||||
|
|
||||||
|
QUANTILES = [
|
||||||
|
0.0000001,
|
||||||
|
0.000001,
|
||||||
|
0.00001,
|
||||||
|
0.0001,
|
||||||
|
0.001,
|
||||||
|
0.01,
|
||||||
|
0.05,
|
||||||
|
0.1,
|
||||||
|
0.3,
|
||||||
|
0.5,
|
||||||
|
0.7,
|
||||||
|
0.9,
|
||||||
|
0.95,
|
||||||
|
0.99,
|
||||||
|
0.999,
|
||||||
|
0.9999,
|
||||||
|
0.99999,
|
||||||
|
0.999999,
|
||||||
|
0.9999999,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def _get_quantiles(device: torch.device, dtype) -> torch.Tensor:
|
||||||
|
return torch.tensor(QUANTILES, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_stats(x_: torch.Tensor, remove_inf=False) -> Dict[str, Any]:
|
||||||
|
if x_.dtype not in [torch.float, torch.double, torch.float16, torch.bfloat16]:
|
||||||
|
return {}
|
||||||
|
x = x_.flatten()
|
||||||
|
if remove_inf:
|
||||||
|
x = x[x.abs() < float("inf")]
|
||||||
|
if x.dtype is not torch.double:
|
||||||
|
x = x.float()
|
||||||
|
xabs = x.abs()
|
||||||
|
quantiles = _get_quantiles(x.device, x.dtype)
|
||||||
|
mean = x.mean()
|
||||||
|
std = x.std()
|
||||||
|
return {
|
||||||
|
"shape": tuple(x_.shape),
|
||||||
|
"mean": mean,
|
||||||
|
"std": std,
|
||||||
|
"skew": (((x - mean) / std) ** 3).double().mean(),
|
||||||
|
"kurtosis": (((x - mean) / std) ** 4).double().mean(),
|
||||||
|
"abs.mean": xabs.mean(),
|
||||||
|
"max": x.max(),
|
||||||
|
"min": x.min(),
|
||||||
|
# Note: `quantile` takes at most 2**24 elements, see
|
||||||
|
# https://github.com/pytorch/pytorch/issues/64947
|
||||||
|
"quantiles": torch.quantile(x[: 2**24], quantiles),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_attn_causal_inplace(logits: torch.Tensor, q_idx, q_len, kv_len) -> None:
|
||||||
|
assert logits.ndim == 4
|
||||||
|
logits[:, :, :, q_idx + kv_len - q_len + 1 :] = -math.inf
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_attn_logits(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
q_idx: List[int],
|
||||||
|
*,
|
||||||
|
causal: bool,
|
||||||
|
cu_seqlens_q: Optional[torch.Tensor] = None,
|
||||||
|
cu_seqlens_k: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert logits.dtype is torch.float32
|
||||||
|
# Handle BlockDiagonalMask
|
||||||
|
if cu_seqlens_q is not None:
|
||||||
|
assert cu_seqlens_k is not None
|
||||||
|
# Expect BHMqMkv
|
||||||
|
assert logits.ndim == 4, logits.shape
|
||||||
|
qs = cu_seqlens_q.tolist()
|
||||||
|
ks = cu_seqlens_k.tolist()
|
||||||
|
q_batchid = []
|
||||||
|
k_batchid = [-2] * logits.shape[-1]
|
||||||
|
q_idx_i = 0
|
||||||
|
for bid, (q0, q1, k0, k1) in enumerate(zip(qs, qs[1:], ks, ks[1:])):
|
||||||
|
for k in range(k0, k1):
|
||||||
|
k_batchid[k] = bid
|
||||||
|
while q_idx_i < len(q_idx) and q_idx[q_idx_i] < q1:
|
||||||
|
q_batchid.append(bid)
|
||||||
|
if causal:
|
||||||
|
_mask_attn_causal_inplace(
|
||||||
|
logits[:, :, q_idx_i : q_idx_i + 1, k0:k1],
|
||||||
|
q_idx[q_idx_i] - q0,
|
||||||
|
q1 - q0,
|
||||||
|
k1 - k0,
|
||||||
|
)
|
||||||
|
q_idx_i += 1
|
||||||
|
mask_out = (
|
||||||
|
torch.tensor(q_batchid, device=logits.device)[None, None, :, None]
|
||||||
|
!= torch.tensor(k_batchid, device=logits.device)[None, None, None, :]
|
||||||
|
)
|
||||||
|
logits[mask_out.expand_as(logits)] = -math.inf
|
||||||
|
assert q_idx_i == len(q_idx)
|
||||||
|
elif causal:
|
||||||
|
for q_idx_i in range(len(q_idx)):
|
||||||
|
_mask_attn_causal_inplace(
|
||||||
|
logits[:, :, q_idx_i : q_idx_i + 1, :],
|
||||||
|
q_idx[q_idx_i],
|
||||||
|
logits.shape[2],
|
||||||
|
logits.shape[3],
|
||||||
|
)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def _attn_queries_subset(num_queries: int) -> List[int]:
|
||||||
|
return list(range(0, num_queries, max(1, num_queries // 128)))
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _compute_attn_stats_sdpa(
|
||||||
|
probe,
|
||||||
|
path: str,
|
||||||
|
# supports arguments both cudnn + flash backends
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attn_mask=None,
|
||||||
|
attn_bias=None,
|
||||||
|
dropout_p=0.0,
|
||||||
|
is_causal=False,
|
||||||
|
scale=None,
|
||||||
|
compute_log_sumexp=True,
|
||||||
|
return_debug_mask=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if scale is None:
|
||||||
|
scale = 1 / (query.shape[-1] ** 0.5)
|
||||||
|
# Filter-out not supported cases
|
||||||
|
if attn_mask is not None or attn_bias is not None or dropout_p != 0.0 or kwargs:
|
||||||
|
probe.store[f"{path}::attn"] = {
|
||||||
|
"query.shape": tuple(query.shape),
|
||||||
|
"key.shape": tuple(key.shape),
|
||||||
|
"value.shape": tuple(value.shape),
|
||||||
|
"attn_mask": attn_mask.shape if attn_mask is not None else None,
|
||||||
|
"dropout_p": dropout_p,
|
||||||
|
"is_causal": is_causal,
|
||||||
|
"scale": scale,
|
||||||
|
"unk_kwargs": list(kwargs.keys()),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
# Take a subset of the queries and compute the logits
|
||||||
|
query_s = _attn_queries_subset(query.shape[-2])
|
||||||
|
logits = query[:, :, query_s] @ key.transpose(-1, -2) * scale
|
||||||
|
logits = _mask_attn_logits(logits.float(), query_s, causal=is_causal)
|
||||||
|
p = logits.float().softmax(-1)
|
||||||
|
masked_logsoft = logits.log_softmax(-1).where(
|
||||||
|
(logits > -math.inf), torch.zeros_like(logits)
|
||||||
|
)
|
||||||
|
entropy = -(p * masked_logsoft).sum(-1)
|
||||||
|
probe.log_tensor(f"{path}::attn_entropy", entropy)
|
||||||
|
probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _compute_attn_stats_flash(
|
||||||
|
probe,
|
||||||
|
path: str,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
cu_seqlens_q: Optional[torch.Tensor],
|
||||||
|
cu_seqlens_k: Optional[torch.Tensor],
|
||||||
|
seqused_k: Optional[torch.Tensor],
|
||||||
|
max_seqlen_q: int,
|
||||||
|
max_seqlen_k: int,
|
||||||
|
p: float,
|
||||||
|
softmax_scale: float,
|
||||||
|
is_causal: bool,
|
||||||
|
window_left: int,
|
||||||
|
window_right: int,
|
||||||
|
return_softmax: bool,
|
||||||
|
block_tables: Optional[torch.Tensor],
|
||||||
|
unpadded_lse: bool = False,
|
||||||
|
) -> None:
|
||||||
|
# Filter-out not supported cases
|
||||||
|
if (
|
||||||
|
seqused_k is not None
|
||||||
|
or p != 0.0
|
||||||
|
or window_left >= 0
|
||||||
|
or window_right >= 0
|
||||||
|
or block_tables is not None
|
||||||
|
):
|
||||||
|
probe.store[f"{path}::attn"] = {
|
||||||
|
"query.shape": tuple(query.shape),
|
||||||
|
"key.shape": tuple(key.shape),
|
||||||
|
"value.shape": tuple(value.shape),
|
||||||
|
"op": "flash",
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
if cu_seqlens_q is not None:
|
||||||
|
assert query.ndim == 3, query.shape
|
||||||
|
query, key, value = query[None], key[None], value[None]
|
||||||
|
assert query.ndim == 4, query.shape
|
||||||
|
|
||||||
|
# Take a subset of the queries and compute the logits
|
||||||
|
query_s = _attn_queries_subset(query.shape[1])
|
||||||
|
logits = (
|
||||||
|
query[:, query_s].transpose(1, 2)
|
||||||
|
@ key.transpose(1, 2).transpose(-1, -2)
|
||||||
|
* softmax_scale
|
||||||
|
)
|
||||||
|
logits = _mask_attn_logits(
|
||||||
|
logits.float(),
|
||||||
|
query_s,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
p = logits.float().softmax(-1)
|
||||||
|
masked_logsoft = logits.log_softmax(-1).where(
|
||||||
|
(logits > -math.inf), torch.zeros_like(logits)
|
||||||
|
)
|
||||||
|
entropy = -(p * masked_logsoft).sum(-1)
|
||||||
|
probe.log_tensor(f"{path}::attn_entropy", entropy)
|
||||||
|
probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _tensors_to_python(x):
|
||||||
|
if not isinstance(x, torch.Tensor):
|
||||||
|
return x
|
||||||
|
return x.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
# class syntax
|
||||||
|
class LinearBwType(Enum):
|
||||||
|
DW = 1
|
||||||
|
DX = 2
|
||||||
|
UNKNOWN = 3
|
||||||
|
|
||||||
|
|
||||||
|
class AutoProbeD(TorchDispatchMode):
|
||||||
|
def __init__(self, module: nn.Module, write_file: Optional[str] = None) -> None:
|
||||||
|
self.write_file = Path(write_file) if write_file is not None else None
|
||||||
|
self.write_tensors_tmpdir: Optional[Path] = None
|
||||||
|
self.compile_disabler = TorchCompileDisabler(module)
|
||||||
|
self.mod_tracker = ModuleTracker()
|
||||||
|
self.count_per_path: Dict[str, int] = defaultdict(int)
|
||||||
|
self.store: Dict[str, Dict[str, Any]] = {}
|
||||||
|
self.linear_data: Dict[str, Tuple[Any, Any, Any, Any, Any]] = {}
|
||||||
|
self.uid_to_path: Dict[str, str] = {}
|
||||||
|
self.metadata: Any = None
|
||||||
|
self.enabled = False
|
||||||
|
self.verbose = bool(int(os.environ.get("PROBE_VERBOSE", "0")))
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
global _PROBING_ENABLED
|
||||||
|
assert not self.enabled, "Entered probe twice"
|
||||||
|
self.compile_disabler.__enter__()
|
||||||
|
self.mod_tracker.__enter__()
|
||||||
|
super().__enter__()
|
||||||
|
self.enabled = True
|
||||||
|
_PROBING_ENABLED = True
|
||||||
|
# self._setup_tensors_logging()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args) -> None:
|
||||||
|
global _PROBING_ENABLED
|
||||||
|
assert self.enabled, "Exiting probe without entering it"
|
||||||
|
super().__exit__(*args)
|
||||||
|
self.mod_tracker.__exit__(*args)
|
||||||
|
self.compile_disabler.__exit__(*args)
|
||||||
|
self._flush_and_clear()
|
||||||
|
_PROBING_ENABLED = False
|
||||||
|
self.enabled = False
|
||||||
|
|
||||||
|
def _setup_tensors_logging(self):
|
||||||
|
if self.write_file is not None:
|
||||||
|
self.write_file.parent.mkdir(exist_ok=True)
|
||||||
|
self.write_tensors_tmpdir = (
|
||||||
|
self.write_file.parent
|
||||||
|
/ f"{self.write_file.name}-tmp-{str(uuid.uuid4())[:8]}"
|
||||||
|
)
|
||||||
|
self.write_tensors_tmpdir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
def _flush_and_clear(self) -> None:
|
||||||
|
if self.write_file is not None:
|
||||||
|
dump_data = tree_map(_tensors_to_python, self.store)
|
||||||
|
with self.write_file.open("a") as fd:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"data": dump_data,
|
||||||
|
"meta": self.metadata,
|
||||||
|
"version": 2,
|
||||||
|
"quantiles": QUANTILES,
|
||||||
|
},
|
||||||
|
fd,
|
||||||
|
)
|
||||||
|
fd.write("\n")
|
||||||
|
if self.write_tensors_tmpdir is not None:
|
||||||
|
assert self.write_file is not None
|
||||||
|
dump_dir = self.write_tensors_tmpdir.parent / f"{self.write_file.name}-dump"
|
||||||
|
dump_dir.mkdir(exist_ok=True)
|
||||||
|
dir_name = ""
|
||||||
|
if "it" in self.metadata:
|
||||||
|
dir_name = f"it{int(self.metadata['it']):010}"
|
||||||
|
if dir_name == "" or (dump_dir / dir_name).exists():
|
||||||
|
num_files = len(list(dump_dir.glob(f"{dir_name}v*")))
|
||||||
|
dir_name = f"{dir_name}v{num_files}"
|
||||||
|
dump_dir = dump_dir / dir_name
|
||||||
|
assert not dump_dir.exists()
|
||||||
|
self.write_tensors_tmpdir.rename(dump_dir)
|
||||||
|
self.write_tensors_tmpdir = None
|
||||||
|
self.store.clear()
|
||||||
|
self.count_per_path.clear()
|
||||||
|
self.uid_to_path.clear()
|
||||||
|
|
||||||
|
def _find_bw_path_and_type(
|
||||||
|
self, path: str, out: torch.Tensor, args
|
||||||
|
) -> Tuple[str, LinearBwType]:
|
||||||
|
"""
|
||||||
|
We are in the BW pass, and process a GEMM.
|
||||||
|
Let's figure out:
|
||||||
|
(1) The path for the FW pass (might differ in case of ModuleTracker bug)
|
||||||
|
(2) The type of BW pass (eg `dw` or `dx`)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _is_path_correct_dw(path: str) -> bool:
|
||||||
|
# dW.t = dY.t @ X
|
||||||
|
in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
|
||||||
|
return out.shape == (w_shape[1], w_shape[0]) and torch.allclose(
|
||||||
|
input_sm, args[1][:4, :4]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_path_correct_dx(path: str) -> bool:
|
||||||
|
# dX = dY @ W.t
|
||||||
|
in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
|
||||||
|
return out.shape == in_shape and torch.allclose(weight_sm, args[1][:4, :4])
|
||||||
|
|
||||||
|
if path in self.linear_data:
|
||||||
|
if _is_path_correct_dw(path):
|
||||||
|
return path, LinearBwType.DW
|
||||||
|
if _is_path_correct_dx(path):
|
||||||
|
return path, LinearBwType.DX
|
||||||
|
for candidate_path in self.mod_tracker.parents:
|
||||||
|
if candidate_path not in self.linear_data:
|
||||||
|
continue
|
||||||
|
if _is_path_correct_dw(candidate_path):
|
||||||
|
return candidate_path, LinearBwType.DW
|
||||||
|
if _is_path_correct_dx(candidate_path):
|
||||||
|
return candidate_path, LinearBwType.DX
|
||||||
|
return path, LinearBwType.UNKNOWN
|
||||||
|
|
||||||
|
def log_tensor(self, name: str, x: torch.Tensor, **kwargs) -> None:
|
||||||
|
self.store[name] = _get_stats(x, **kwargs)
|
||||||
|
if self.write_tensors_tmpdir is not None:
|
||||||
|
name_safe = name.replace("::", "__").replace("/", "")
|
||||||
|
torch.save(x, self.write_tensors_tmpdir / f"{name_safe}.pkl")
|
||||||
|
|
||||||
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||||
|
kwargs = kwargs if kwargs else {}
|
||||||
|
path = None
|
||||||
|
# Find longest path
|
||||||
|
for p in self.mod_tracker.parents:
|
||||||
|
if p == "Global":
|
||||||
|
continue
|
||||||
|
if path is None or len(p) > len(path):
|
||||||
|
path = p
|
||||||
|
if path is None:
|
||||||
|
path = "Global"
|
||||||
|
path = path.replace("._checkpoint_wrapped_module", "")
|
||||||
|
out = func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Handle linear layers
|
||||||
|
if func._overloadpacket in [torch.ops.aten.addmm, torch.ops.aten.mm]:
|
||||||
|
weight: torch.Tensor
|
||||||
|
input: torch.Tensor
|
||||||
|
if not self.mod_tracker.is_bw:
|
||||||
|
# (technically, weight is transposed)
|
||||||
|
if func._overloadpacket == torch.ops.aten.addmm:
|
||||||
|
_bias, input, weight = args[:3]
|
||||||
|
else:
|
||||||
|
assert func._overloadpacket == torch.ops.aten.mm
|
||||||
|
input, weight = args[:2]
|
||||||
|
self.log_tensor(f"{path}::in", input)
|
||||||
|
self.log_tensor(f"{path}::w", weight)
|
||||||
|
self.log_tensor(f"{path}::out", out)
|
||||||
|
self.linear_data[path] = (
|
||||||
|
input.shape,
|
||||||
|
weight.shape,
|
||||||
|
out.shape,
|
||||||
|
input[:4, :4].clone(),
|
||||||
|
weight[:4, :4].T.clone(),
|
||||||
|
)
|
||||||
|
elif func._overloadpacket == torch.ops.aten.mm:
|
||||||
|
# XXX: Try to find the actual path for the linear layer
|
||||||
|
# This is messed with with Francisco's FSDP sometimes
|
||||||
|
new_path, bwtype = self._find_bw_path_and_type(path, out, args)
|
||||||
|
if new_path != path:
|
||||||
|
if self.verbose:
|
||||||
|
print(f"E: Fixing path `{path}` -> `{new_path}")
|
||||||
|
path = new_path
|
||||||
|
|
||||||
|
if bwtype == LinearBwType.DW:
|
||||||
|
# dW.t = dY.t @ X
|
||||||
|
self.log_tensor(f"{path}::w.g", out)
|
||||||
|
elif bwtype == LinearBwType.DX:
|
||||||
|
# dX = dY @ W.t
|
||||||
|
self.log_tensor(f"{path}::in.g", out)
|
||||||
|
self.log_tensor(f"{path}::out.g", args[0])
|
||||||
|
elif func._overloadpacket in [
|
||||||
|
torch.ops.aten._scaled_dot_product_flash_attention,
|
||||||
|
torch.ops.aten._scaled_dot_product_cudnn_attention,
|
||||||
|
]:
|
||||||
|
_, kwargs = normalize_function(
|
||||||
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||||
|
)
|
||||||
|
_compute_attn_stats_sdpa(self, path, **kwargs)
|
||||||
|
elif func._overloadpacket == fmha.flash.FwOp.OPERATOR:
|
||||||
|
_, kwargs = normalize_function(
|
||||||
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||||
|
)
|
||||||
|
_compute_attn_stats_flash(self, path, **kwargs)
|
||||||
|
elif func._overloadpacket == torch.ops.torchprobe.log:
|
||||||
|
uid = args[2]
|
||||||
|
path = self.uid_to_path.setdefault(uid, path)
|
||||||
|
self.log_tensor(f"{path}::{args[1]}", args[0])
|
||||||
|
if self.verbose:
|
||||||
|
print(f"{'[BW]' if self.mod_tracker.is_bw else '[FW]'} `{path}`: {func}")
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _find_all_submodules_compiled(out: List[nn.Module], module: nn.Module) -> None:
|
||||||
|
if module._compiled_call_impl is not None:
|
||||||
|
out.append(module)
|
||||||
|
for c in module.children():
|
||||||
|
_find_all_submodules_compiled(out, module=c)
|
||||||
|
|
||||||
|
|
||||||
|
class TorchCompileDisabler:
|
||||||
|
def __init__(self, module: nn.Module) -> None:
|
||||||
|
self.module = module
|
||||||
|
self.submodules_compiled: List[nn.Module] = []
|
||||||
|
self.compiled_call_impl: List[Any] = []
|
||||||
|
self.disable_compile = torch.compiler.disable()
|
||||||
|
torch._dynamo.config.raise_on_ctx_manager_usage = False # type: ignore
|
||||||
|
|
||||||
|
def __enter__(self) -> None:
|
||||||
|
# Remove all `_compiled_call_impl` attributes to effectively
|
||||||
|
# "undo" compilation
|
||||||
|
self.submodules_compiled.clear()
|
||||||
|
_find_all_submodules_compiled(self.submodules_compiled, self.module)
|
||||||
|
self.compiled_call_impl = [
|
||||||
|
m._compiled_call_impl for m in self.submodules_compiled
|
||||||
|
]
|
||||||
|
for m in self.submodules_compiled:
|
||||||
|
m._compiled_call_impl = None
|
||||||
|
self.disable_compile.__enter__() # type: ignore
|
||||||
|
|
||||||
|
def __exit__(self, *args) -> None:
|
||||||
|
self.disable_compile.__exit__(*args) # type: ignore
|
||||||
|
for m, c_impl in zip(self.submodules_compiled, self.compiled_call_impl):
|
||||||
|
m._compiled_call_impl = c_impl
|
||||||
|
self.compiled_call_impl = []
|
||||||
|
|
||||||
|
|
||||||
|
Probe = AutoProbeD
|
||||||
|
|
||||||
|
# EXAMPLE USAGE
|
||||||
|
d = 512
|
||||||
|
seqlen = 4
|
||||||
|
bs = 2
|
||||||
|
|
||||||
|
|
||||||
|
class Attention1(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
attn_bias = fmha.attn_bias.LowerTriangularFromBottomRightMask()
|
||||||
|
return fmha.memory_efficient_attention(x, x, x, attn_bias=attn_bias).reshape(
|
||||||
|
[x.shape[0], seqlen, -1]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention2(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
|
||||||
|
[seqlen] * bs
|
||||||
|
).make_causal()
|
||||||
|
xr = x.reshape([1, 2 * seqlen, x.shape[2], x.shape[3]])
|
||||||
|
return fmha.memory_efficient_attention(xr, xr, xr, attn_bias=attn_bias).reshape(
|
||||||
|
[x.shape[0], seqlen, -1]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionSDPA(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.wo = nn.Linear(d, d)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
return self.wo(
|
||||||
|
F.scaled_dot_product_attention(x, x, x)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.reshape([x.shape[0], seqlen, -1])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionSDPAFlash(AttentionSDPA):
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
||||||
|
return self.wo(
|
||||||
|
F.scaled_dot_product_attention(x, x, x)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.reshape([x.shape[0], seqlen, -1])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.head = nn.Linear(d, 16)
|
||||||
|
self.trunk = nn.Sequential(
|
||||||
|
nn.Linear(d, d),
|
||||||
|
nn.Linear(d, d),
|
||||||
|
)
|
||||||
|
self.q_proj = nn.Linear(d, d, bias=False)
|
||||||
|
self.trunk.compile()
|
||||||
|
self.attn1 = Attention1()
|
||||||
|
self.attn2 = Attention2()
|
||||||
|
self.attnSDPA = AttentionSDPA()
|
||||||
|
self.attnSDPAflash = AttentionSDPAFlash()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, nHeads, D = x.shape[0], d // 64, 64
|
||||||
|
x = self.q_proj(x).reshape([B, seqlen, nHeads, D])
|
||||||
|
x = self.attn1(x) + self.attn2(x) + self.attnSDPA(x) + self.attnSDPAflash(x)
|
||||||
|
x = log_stats(x, "attns_out")
|
||||||
|
return self.head(self.trunk(x))
|
||||||
|
|
||||||
|
|
||||||
|
def test_masking() -> None:
|
||||||
|
q_seqlen = [1, 1, 14, 12]
|
||||||
|
kv_seqlen = [2, 2, 14, 18]
|
||||||
|
attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||||
|
q_seqlen, kv_seqlen
|
||||||
|
).make_causal_from_bottomright()
|
||||||
|
logits = torch.randn(
|
||||||
|
[1, 1, sum(q_seqlen), sum(kv_seqlen)], dtype=torch.float32, device="cuda"
|
||||||
|
)
|
||||||
|
bias = attn_bias.materialize(logits.shape, dtype=logits.dtype, device=logits.device)
|
||||||
|
logits_masked = logits.clone()
|
||||||
|
_mask_attn_logits(
|
||||||
|
logits_masked,
|
||||||
|
list(range(logits.shape[2])),
|
||||||
|
causal=True,
|
||||||
|
cu_seqlens_q=attn_bias.q_seqinfo.seqstart,
|
||||||
|
cu_seqlens_k=attn_bias.k_seqinfo.seqstart,
|
||||||
|
)
|
||||||
|
assert (logits + bias == logits_masked).all().item()
|
||||||
|
|
||||||
|
|
||||||
|
def test_toy_model() -> None:
|
||||||
|
# Test masking
|
||||||
|
kw = dict(device="cuda", dtype=torch.float16)
|
||||||
|
x = torch.randn([bs, seqlen, d], **kw)
|
||||||
|
m = Model()
|
||||||
|
m.head = checkpoint_wrapper(
|
||||||
|
m.head, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
|
||||||
|
)
|
||||||
|
m.to(**kw)
|
||||||
|
m.compile()
|
||||||
|
optim = torch.optim.SGD(m.parameters(), lr=0.0)
|
||||||
|
probe = AutoProbeD(m, "./probe.json")
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
with contextlib.ExitStack() as stack:
|
||||||
|
print(f"########### STEP {i}")
|
||||||
|
if i % 4 == 1:
|
||||||
|
stack.enter_context(probe)
|
||||||
|
probe.metadata = {"it": i}
|
||||||
|
y = m(x)
|
||||||
|
g = torch.randn_like(y)
|
||||||
|
y.backward(g)
|
||||||
|
if i % 4 == 1:
|
||||||
|
assert probe.enabled
|
||||||
|
# Make sure we registered all linears
|
||||||
|
print(list(probe.store.keys()))
|
||||||
|
for key in [
|
||||||
|
"Model::attns_out",
|
||||||
|
"Model::attns_out.g",
|
||||||
|
"Model.attn1::attn_logits",
|
||||||
|
"Model.attn2::attn_logits",
|
||||||
|
"Model.attnSDPA::attn_logits",
|
||||||
|
"Model.attnSDPAflash::attn_logits",
|
||||||
|
"Model.head::w",
|
||||||
|
"Model.head::w.g",
|
||||||
|
"Model.head::in",
|
||||||
|
"Model.head::in.g",
|
||||||
|
"Model.head::out",
|
||||||
|
"Model.head::out.g",
|
||||||
|
"Model.trunk.0::in",
|
||||||
|
"Model.trunk.1::in",
|
||||||
|
]:
|
||||||
|
assert key in probe.store, f"Missing key: '{key}'"
|
||||||
|
# .. and that the values are correct
|
||||||
|
for key, tensor in [
|
||||||
|
("Model.head::w", m.head.weight),
|
||||||
|
("Model.head::w.g", m.head.weight.grad),
|
||||||
|
("Model.q_proj::in", x),
|
||||||
|
("Model.q_proj::w.g", m.q_proj.weight.grad),
|
||||||
|
("Model.head::out", y),
|
||||||
|
("Model.head::out.g", g),
|
||||||
|
]:
|
||||||
|
assert key in probe.store, f"Missing key: '{key}'"
|
||||||
|
assert torch.allclose(
|
||||||
|
probe.store[key]["abs.mean"], tensor.float().abs().mean()
|
||||||
|
), f"'{key}' mismatches"
|
||||||
|
# Check we don't have `nans`
|
||||||
|
for key, value in probe.store.items():
|
||||||
|
if "abs.mean" in value:
|
||||||
|
assert math.isfinite(
|
||||||
|
value["abs.mean"].item()
|
||||||
|
), f"Inf/Nan for {key}"
|
||||||
|
optim.step()
|
||||||
|
optim.zero_grad()
|
133
bytelatent/profiling.py
Normal file
133
bytelatent/profiling.py
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch.distributed
|
||||||
|
import wandb
|
||||||
|
import xformers.profiler
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from torch.profiler.profiler import profile
|
||||||
|
from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler
|
||||||
|
|
||||||
|
from bytelatent.distributed import get_is_master
|
||||||
|
|
||||||
|
|
||||||
|
class ProfilerArgs(BaseModel):
|
||||||
|
run: bool = False
|
||||||
|
trace_folder: str = "profiling"
|
||||||
|
mem_warmup: int = 100
|
||||||
|
mem_steps: int = 2
|
||||||
|
profile_warmup: int = 102
|
||||||
|
profile_steps: int = 2
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def perfetto_to_html(json_file, html_file):
|
||||||
|
import gzip
|
||||||
|
import string
|
||||||
|
|
||||||
|
import viztracer
|
||||||
|
|
||||||
|
root = os.path.dirname(viztracer.__file__)
|
||||||
|
sub = {}
|
||||||
|
json_file = gzip.open(json_file) if ".gz" in str(json_file) else open(json_file)
|
||||||
|
with open(
|
||||||
|
os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
tmpl = f.read()
|
||||||
|
with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f:
|
||||||
|
sub["trace_viewer_full"] = f.read()
|
||||||
|
with json_file as j:
|
||||||
|
content = j.read()
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
content = content.decode("utf-8")
|
||||||
|
sub["json_data"] = content.replace("</script>", "<\\/script>") # type: ignore
|
||||||
|
with open(html_file, "w+", encoding="utf-8") as output_file:
|
||||||
|
output_file.write(string.Template(tmpl).substitute(sub))
|
||||||
|
|
||||||
|
|
||||||
|
class PyTorchProfilerWandb(PyTorchProfiler):
|
||||||
|
def __init__(self, main_profiler) -> None:
|
||||||
|
self.main_profiler = main_profiler
|
||||||
|
self.num_steps = 0
|
||||||
|
self.pytorch_profiler = torch.profiler.profile(
|
||||||
|
on_trace_ready=self._on_trace,
|
||||||
|
profile_memory=True,
|
||||||
|
record_shapes=True,
|
||||||
|
# With stack gives huge profile traces
|
||||||
|
# and bugs out because of some non ascii
|
||||||
|
# character somewhere in pytorch
|
||||||
|
with_stack=False,
|
||||||
|
with_flops=True,
|
||||||
|
activities=self.ACTIVITIES,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _analyze_trace(self, prof: profile):
|
||||||
|
logger.info("Begin analyze trace")
|
||||||
|
super()._analyze_trace(prof)
|
||||||
|
logger.info("End analyze trace")
|
||||||
|
|
||||||
|
def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
|
||||||
|
super()._on_trace(prof)
|
||||||
|
if get_is_master() and wandb.run is not None:
|
||||||
|
filename = list(
|
||||||
|
Path(self.main_profiler.output_dir).glob(
|
||||||
|
"profile_CPU_CUDA*/*.pt.trace.json*"
|
||||||
|
)
|
||||||
|
)[0]
|
||||||
|
html_path = str(filename).replace(".json", ".html")
|
||||||
|
perfetto_to_html(filename, html_path)
|
||||||
|
wandb.log({"profile_trace": wandb.Html(html_path)})
|
||||||
|
|
||||||
|
|
||||||
|
class MemSnapshotsProfilerWandb(MemSnapshotsProfiler):
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
super().__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
if get_is_master() and wandb.run is not None:
|
||||||
|
filename = list(
|
||||||
|
Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html")
|
||||||
|
)[0]
|
||||||
|
wandb.log({"memory_trace": wandb.Html(open(filename), inject=False)})
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def maybe_run_profiler(dump_dir, module, config: ProfilerArgs):
|
||||||
|
# get user defined profiler settings
|
||||||
|
|
||||||
|
if config.run:
|
||||||
|
trace_dir = os.path.join(dump_dir, config.trace_folder)
|
||||||
|
|
||||||
|
logger.info(f"Profiling active. Traces will be saved at {trace_dir}")
|
||||||
|
|
||||||
|
if get_is_master() and not os.path.exists(trace_dir):
|
||||||
|
os.makedirs(trace_dir)
|
||||||
|
if torch.distributed.is_initialized():
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
|
with xformers.profiler.profile(
|
||||||
|
output_dir=trace_dir,
|
||||||
|
module=module,
|
||||||
|
schedule=[
|
||||||
|
(
|
||||||
|
MemSnapshotsProfilerWandb,
|
||||||
|
config.mem_warmup,
|
||||||
|
config.mem_warmup + config.mem_steps,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
PyTorchProfilerWandb,
|
||||||
|
config.profile_warmup,
|
||||||
|
config.profile_warmup + config.profile_steps,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) as profiler:
|
||||||
|
yield profiler
|
||||||
|
|
||||||
|
else:
|
||||||
|
torch_profiler = contextlib.nullcontext()
|
||||||
|
yield None
|
237
bytelatent/stool.py
Normal file
237
bytelatent/stool.py
Normal file
|
@ -0,0 +1,237 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StoolArgs:
|
||||||
|
config: Any = None
|
||||||
|
launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
|
||||||
|
script: str = "apps.main.train" # The script to run.
|
||||||
|
copy_code: bool = True # Wether to copy code to dump dir
|
||||||
|
dirs_exists_ok: bool = (
|
||||||
|
False # Wether to copy new code and config and run regardless that dir exists
|
||||||
|
)
|
||||||
|
override: bool = False # Wether to delete dump dir and restart
|
||||||
|
nodes: int = -1 # The number of nodes to run the job on.
|
||||||
|
ngpu: int = 8 # The number of GPUs required per node.
|
||||||
|
ncpu: int = 16 # The number of CPUs allocated per GPU.
|
||||||
|
mem: str = "" # The amount of memory to allocate.
|
||||||
|
anaconda: str = "default" # The path to the anaconda environment.
|
||||||
|
constraint: str = "" # The constraint on the nodes.
|
||||||
|
exclude: str = "" # The nodes to exclude.
|
||||||
|
time: int = -1 # The time limit of the job (in minutes).
|
||||||
|
account: str = ""
|
||||||
|
qos: str = ""
|
||||||
|
partition: str = "learn"
|
||||||
|
stdout: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
SBATCH_COMMAND = """#!/bin/bash
|
||||||
|
|
||||||
|
{exclude}
|
||||||
|
{qos}
|
||||||
|
{account}
|
||||||
|
{constraint}
|
||||||
|
#SBATCH --job-name={name}
|
||||||
|
#SBATCH --nodes={nodes}
|
||||||
|
#SBATCH --gres=gpu:{ngpus}
|
||||||
|
#SBATCH --cpus-per-gpu={ncpu}
|
||||||
|
#SBATCH --time={time}
|
||||||
|
#SBATCH --partition={partition}
|
||||||
|
#SBATCH --mem={mem}
|
||||||
|
|
||||||
|
#SBATCH --output={dump_dir}/logs/%j/%j.stdout
|
||||||
|
#SBATCH --error={dump_dir}/logs/%j/%j.stderr
|
||||||
|
|
||||||
|
#SBATCH --open-mode=append
|
||||||
|
#SBATCH --signal=USR2@120
|
||||||
|
#SBATCH --distribution=block
|
||||||
|
|
||||||
|
# Mimic the effect of "conda init", which doesn't work for scripts
|
||||||
|
eval "$({conda_exe} shell.bash hook)"
|
||||||
|
source activate {conda_env_path}
|
||||||
|
|
||||||
|
{go_to_code_dir}
|
||||||
|
|
||||||
|
export OMP_NUM_THREADS=1
|
||||||
|
export LAUNCH_WITH="SBATCH"
|
||||||
|
export DUMP_DIR={dump_dir}
|
||||||
|
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def copy_dir(input_dir: str, output_dir: str) -> None:
|
||||||
|
print(f"Copying : {input_dir}\n" f"to : {output_dir} ...")
|
||||||
|
assert os.path.isdir(input_dir), f"{input_dir} is not a directory"
|
||||||
|
assert os.path.isdir(output_dir), f"{output_dir} is not a directory"
|
||||||
|
rsync_cmd = (
|
||||||
|
f"rsync -arm --copy-links "
|
||||||
|
f"--include '**/' "
|
||||||
|
f"--include '*.py' "
|
||||||
|
f"--exclude='*' "
|
||||||
|
f"{input_dir}/ {output_dir}"
|
||||||
|
)
|
||||||
|
print(f"Copying command: {rsync_cmd}")
|
||||||
|
subprocess.call([rsync_cmd], shell=True)
|
||||||
|
print("Copy done.")
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve_max_time_per_partition() -> Dict[str, int]:
|
||||||
|
# retrieve partition max times (a bit slow)
|
||||||
|
|
||||||
|
sinfo = json.loads(subprocess.check_output("sinfo --json", shell=True))["sinfo"]
|
||||||
|
max_times: Dict[str, int] = {}
|
||||||
|
|
||||||
|
for info in sinfo:
|
||||||
|
if info["partition"]["maximums"]["time"]["infinite"]:
|
||||||
|
max_times[info["partition"]["name"]] = 14 * 24 * 60 # 14 days
|
||||||
|
else:
|
||||||
|
max_times[info["partition"]["name"]] = info["partition"]["maximums"][
|
||||||
|
"time"
|
||||||
|
][
|
||||||
|
"number"
|
||||||
|
] # in minutes
|
||||||
|
|
||||||
|
return max_times
|
||||||
|
|
||||||
|
|
||||||
|
def validate_args(args) -> None:
|
||||||
|
# Set maximum time limit if not specified
|
||||||
|
if args.time == -1:
|
||||||
|
max_times = retrieve_max_time_per_partition()
|
||||||
|
args.time = max_times.get(
|
||||||
|
args.partition, 3 * 24 * 60
|
||||||
|
) # Default to 3 days if not found
|
||||||
|
print(
|
||||||
|
f"No time limit specified, using max time for partitions: {args.time} minutes"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.constraint:
|
||||||
|
args.constraint = f"#SBATCH --constraint={args.constraint}"
|
||||||
|
|
||||||
|
if args.account:
|
||||||
|
args.account = f"#SBATCH --account={args.account}"
|
||||||
|
|
||||||
|
if args.qos:
|
||||||
|
args.qos = f"#SBATCH --qos={args.qos}"
|
||||||
|
|
||||||
|
if getattr(args, "exclude", ""):
|
||||||
|
args.exclude = f"#SBATCH --exclude={args.exclude}"
|
||||||
|
|
||||||
|
if hasattr(args, "anaconda") and args.anaconda:
|
||||||
|
if args.anaconda == "default":
|
||||||
|
args.anaconda = (
|
||||||
|
subprocess.check_output("which python", shell=True)
|
||||||
|
.decode("ascii")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
args.anaconda = f"{args.anaconda}/bin/python"
|
||||||
|
assert os.path.isfile(args.anaconda)
|
||||||
|
|
||||||
|
args.mem = args.mem or "0"
|
||||||
|
|
||||||
|
assert args.partition
|
||||||
|
assert args.ngpu > 0
|
||||||
|
assert args.ncpu > 0
|
||||||
|
assert args.nodes > 0
|
||||||
|
assert args.time > 0
|
||||||
|
assert args.partition
|
||||||
|
|
||||||
|
|
||||||
|
def launch_job(args: StoolArgs):
|
||||||
|
# Set up args default and validate them depending on the cluster or partition requested
|
||||||
|
validate_args(args)
|
||||||
|
dump_dir = args.config["dump_dir"]
|
||||||
|
job_name = args.config["name"]
|
||||||
|
print("Creating directories...")
|
||||||
|
os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
|
||||||
|
if args.override:
|
||||||
|
confirm = input(
|
||||||
|
f"Are you sure you want to delete the directory '{dump_dir}'? This action cannot be undone. (yes/no): "
|
||||||
|
)
|
||||||
|
if confirm.lower() == "yes":
|
||||||
|
shutil.rmtree(dump_dir)
|
||||||
|
print(f"Directory '{dump_dir}' has been deleted.")
|
||||||
|
else:
|
||||||
|
print("Operation cancelled.")
|
||||||
|
return
|
||||||
|
if args.copy_code:
|
||||||
|
os.makedirs(f"{dump_dir}/code", exist_ok=args.dirs_exists_ok)
|
||||||
|
print("Copying code ...")
|
||||||
|
copy_dir(os.getcwd(), f"{dump_dir}/code")
|
||||||
|
|
||||||
|
print("Saving config file ...")
|
||||||
|
with open(f"{dump_dir}/base_config.yaml", "w") as cfg:
|
||||||
|
cfg.write(OmegaConf.to_yaml(args.config))
|
||||||
|
|
||||||
|
conda_exe = os.environ.get("CONDA_EXE", "conda")
|
||||||
|
conda_env_path = os.path.dirname(os.path.dirname(args.anaconda))
|
||||||
|
log_output = (
|
||||||
|
"-o $DUMP_DIR/logs/%j/%j_%t.out -e $DUMP_DIR/logs/%j/%j_%t.err"
|
||||||
|
if not args.stdout
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
sbatch = SBATCH_COMMAND.format(
|
||||||
|
name=job_name,
|
||||||
|
script=args.script,
|
||||||
|
dump_dir=dump_dir,
|
||||||
|
nodes=args.nodes,
|
||||||
|
tasks=args.nodes * args.ngpu,
|
||||||
|
nodes_per_run=args.nodes,
|
||||||
|
ngpus=args.ngpu,
|
||||||
|
ncpu=args.ncpu,
|
||||||
|
mem=args.mem,
|
||||||
|
qos=args.qos,
|
||||||
|
account=args.account,
|
||||||
|
constraint=args.constraint,
|
||||||
|
exclude=args.exclude,
|
||||||
|
time=args.time,
|
||||||
|
partition=args.partition,
|
||||||
|
conda_exe=conda_exe,
|
||||||
|
conda_env_path=conda_env_path,
|
||||||
|
log_output=log_output,
|
||||||
|
go_to_code_dir=f"cd {dump_dir}/code/" if args.copy_code else "",
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Writing sbatch command ...")
|
||||||
|
with open(f"{dump_dir}/submit.slurm", "w") as f:
|
||||||
|
f.write(sbatch)
|
||||||
|
|
||||||
|
print("Submitting job ...")
|
||||||
|
os.system(f"{args.launcher} {dump_dir}/submit.slurm")
|
||||||
|
|
||||||
|
print("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
"""
|
||||||
|
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
|
||||||
|
This accepts arguments as a dot list
|
||||||
|
So if the dataclass looks like
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DummyArgs:
|
||||||
|
name: str
|
||||||
|
mode: LMTransformerArgs
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LMTransformerArgs:
|
||||||
|
dim: int
|
||||||
|
|
||||||
|
Then you can pass model.dim=32 to change values in LMTransformerArgs
|
||||||
|
or just name=tictac for top level attributes.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Update this to blt code")
|
||||||
|
args = OmegaConf.from_cli()
|
||||||
|
args.config = OmegaConf.load(args.config)
|
||||||
|
args = dataclass_from_dict(StoolArgs, args)
|
||||||
|
launch_job(args)
|
471
bytelatent/test_blt.py
Normal file
471
bytelatent/test_blt.py
Normal file
|
@ -0,0 +1,471 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import os
|
||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from bytelatent.constants import BLT_DATA
|
||||||
|
from bytelatent.data.data_types import Batch
|
||||||
|
from bytelatent.data.ngram_processor import NgramProcessor
|
||||||
|
from bytelatent.model.blt import (
|
||||||
|
ByteLatentTransformer,
|
||||||
|
ByteLatentTransformerArgs,
|
||||||
|
EmbeddingType,
|
||||||
|
compute_hash_embeddings,
|
||||||
|
create_global_transformer,
|
||||||
|
create_local_decoder,
|
||||||
|
create_local_encoder,
|
||||||
|
cross_attn_mask,
|
||||||
|
decoder_patch_ids_from_lengths,
|
||||||
|
get_blt_input,
|
||||||
|
init_embeddings,
|
||||||
|
patch_ids_from_lengths,
|
||||||
|
)
|
||||||
|
from bytelatent.model.transformer import CrossAttention
|
||||||
|
from bytelatent.model.utils import create_causal_mask
|
||||||
|
from bytelatent.optim import OptimArgs, build_optimizer
|
||||||
|
from bytelatent.train import compute_loss
|
||||||
|
|
||||||
|
|
||||||
|
def batch_to_tensors_and_gpu(batch):
|
||||||
|
x = torch.from_numpy(batch.x)
|
||||||
|
y = torch.from_numpy(batch.y)
|
||||||
|
mask = None if batch.mask is None else torch.from_numpy(batch.mask)
|
||||||
|
patch_lengths = (
|
||||||
|
None if batch.patch_lengths is None else torch.from_numpy(batch.patch_lengths)
|
||||||
|
)
|
||||||
|
ngram_ids = None if batch.ngram_ids is None else torch.from_numpy(batch.ngram_ids)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
x = x.cuda()
|
||||||
|
y = y.cuda()
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.cuda()
|
||||||
|
if patch_lengths is not None:
|
||||||
|
patch_lengths = patch_lengths.cuda()
|
||||||
|
if ngram_ids is not None:
|
||||||
|
ngram_ids = ngram_ids.cuda()
|
||||||
|
return x, y, mask, patch_lengths, ngram_ids
|
||||||
|
|
||||||
|
|
||||||
|
def fake_batch():
|
||||||
|
batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"))
|
||||||
|
del batch_dict["x2"]
|
||||||
|
del batch_dict["y2"]
|
||||||
|
del batch_dict["src_names"]
|
||||||
|
return Batch(**batch_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def create_args(cross_attention=False):
|
||||||
|
transformer_args = ByteLatentTransformerArgs(
|
||||||
|
# Base args provided
|
||||||
|
n_heads=8,
|
||||||
|
dim=512,
|
||||||
|
vocab_size=260,
|
||||||
|
# Additional args from command line
|
||||||
|
dim_token=256,
|
||||||
|
patch_size=6,
|
||||||
|
tokenization_mode="bytes",
|
||||||
|
patching_mode="space",
|
||||||
|
tie_local_encoder_decoder_logits=False,
|
||||||
|
data_loader_patching=True,
|
||||||
|
max_encoder_seq_length=12288,
|
||||||
|
pad_to_max_length=True,
|
||||||
|
encoder_lm_loss=False,
|
||||||
|
patching_threshold=3.1439168453216553,
|
||||||
|
encoder_hash_byte_group_size=[4],
|
||||||
|
encoder_hash_byte_group_vocab=50002,
|
||||||
|
encoder_hash_byte_group_nb_functions=3,
|
||||||
|
cross_attn_encoder=cross_attention, # True,
|
||||||
|
cross_attn_decoder=cross_attention, # True,
|
||||||
|
cross_attn_window_encoder=512,
|
||||||
|
cross_attn_window_decoder=512,
|
||||||
|
dim_local_encoder=256,
|
||||||
|
dim_local_decoder=256,
|
||||||
|
cross_attn_k=8,
|
||||||
|
cross_attn_nheads=4,
|
||||||
|
cross_attn_all_layers_decoder=True,
|
||||||
|
cross_attn_all_layers_encoder=True,
|
||||||
|
cross_attn_use_flex_attention=True,
|
||||||
|
cross_attn_init_by_pooling=True,
|
||||||
|
log_patch_lengths=True,
|
||||||
|
non_linearity="swiglu",
|
||||||
|
use_rope=True,
|
||||||
|
recompute_fc1_out=False,
|
||||||
|
recompute_fc3_out=False,
|
||||||
|
recompute_attn=False,
|
||||||
|
custom_bwd=False,
|
||||||
|
layer_ckpt="none",
|
||||||
|
efficient_attn="sdpa",
|
||||||
|
patch_only_encoder=False,
|
||||||
|
patch_only_decoder=False,
|
||||||
|
use_local_encoder_transformer=True,
|
||||||
|
init_use_gaussian=True,
|
||||||
|
init_use_depth="current",
|
||||||
|
attn_bias_type="block_causal",
|
||||||
|
alpha_depth="disabled",
|
||||||
|
max_length=256,
|
||||||
|
local_attention_window_len=512,
|
||||||
|
max_seqlen=12288,
|
||||||
|
downsampling_by_pooling="max",
|
||||||
|
)
|
||||||
|
return transformer_args
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
|
||||||
|
class TestByteLatentTransformer:
|
||||||
|
def test_local_encoder(self):
|
||||||
|
args = create_args()
|
||||||
|
device = torch.device("cuda")
|
||||||
|
local_encoder = create_local_encoder(args).to(device)
|
||||||
|
|
||||||
|
batch = fake_batch()
|
||||||
|
tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
|
||||||
|
|
||||||
|
local_encoder_tokens, _, _ = get_blt_input(
|
||||||
|
tokens=tokens,
|
||||||
|
enforce_patch_size_multiple=False,
|
||||||
|
nb_boe=0,
|
||||||
|
patch_size=local_encoder.patch_size,
|
||||||
|
boe_id=local_encoder.boe_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_ids = patch_ids_from_lengths(
|
||||||
|
patch_lengths, local_encoder_tokens.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_hash_tok_embedding = init_embeddings(
|
||||||
|
args,
|
||||||
|
EmbeddingType.HASH_TOK,
|
||||||
|
local_encoder_dim=local_encoder.dim,
|
||||||
|
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
local_encoder_embeds = compute_hash_embeddings(
|
||||||
|
local_encoder_tokens=local_encoder_tokens,
|
||||||
|
local_encoder=local_encoder,
|
||||||
|
encoder_hash_tok_embedding=encoder_hash_tok_embedding,
|
||||||
|
encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
|
||||||
|
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
|
||||||
|
encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
|
||||||
|
)
|
||||||
|
|
||||||
|
reference_path = os.path.join(BLT_DATA, "local_encoder_tokens.pt")
|
||||||
|
reference_tokens = torch.load(reference_path).to(device)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
local_encoder_tokens,
|
||||||
|
reference_tokens,
|
||||||
|
msg="Generated tokens don't match reference tokens",
|
||||||
|
)
|
||||||
|
|
||||||
|
(h_encoder, h_cross), cache_encoder = local_encoder(
|
||||||
|
tokens=local_encoder_tokens,
|
||||||
|
embeds=local_encoder_embeds,
|
||||||
|
patch_embeds=None,
|
||||||
|
cross_mask=None,
|
||||||
|
num_patches=patch_lengths.shape[1],
|
||||||
|
patch_ids=patch_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert h_encoder is not None
|
||||||
|
assert h_cross is None
|
||||||
|
assert cache_encoder is None
|
||||||
|
|
||||||
|
expected_shape = (
|
||||||
|
local_encoder_tokens.shape[0],
|
||||||
|
local_encoder_tokens.shape[1],
|
||||||
|
local_encoder.dim,
|
||||||
|
)
|
||||||
|
assert h_encoder.shape == expected_shape
|
||||||
|
|
||||||
|
def test_local_encoder_cross_attention(self):
|
||||||
|
args = create_args(cross_attention=True)
|
||||||
|
device = torch.device("cuda")
|
||||||
|
local_encoder = create_local_encoder(args).to(device)
|
||||||
|
|
||||||
|
batch = fake_batch()
|
||||||
|
tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
|
||||||
|
|
||||||
|
local_encoder_tokens, _, _ = get_blt_input(
|
||||||
|
tokens=tokens,
|
||||||
|
enforce_patch_size_multiple=False,
|
||||||
|
nb_boe=0,
|
||||||
|
patch_size=local_encoder.patch_size,
|
||||||
|
boe_id=local_encoder.boe_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_ids = patch_ids_from_lengths(
|
||||||
|
patch_lengths, local_encoder_tokens.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_hash_tok_embedding = init_embeddings(
|
||||||
|
args,
|
||||||
|
EmbeddingType.HASH_TOK,
|
||||||
|
local_encoder_dim=local_encoder.dim,
|
||||||
|
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
cross_attn_mask_enc = cross_attn_mask(
|
||||||
|
patch_ids,
|
||||||
|
patch_lengths,
|
||||||
|
local_encoder_tokens.shape[-1],
|
||||||
|
patches_as_queries=True,
|
||||||
|
cross_attn_k=args.cross_attn_k,
|
||||||
|
window=args.cross_attn_window_encoder,
|
||||||
|
block_mask=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
local_encoder_embeds = compute_hash_embeddings(
|
||||||
|
local_encoder_tokens=local_encoder_tokens,
|
||||||
|
local_encoder=local_encoder,
|
||||||
|
encoder_hash_tok_embedding=encoder_hash_tok_embedding,
|
||||||
|
encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
|
||||||
|
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
|
||||||
|
encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
|
||||||
|
)
|
||||||
|
(h_encoder, h_cross), cache_encoder = local_encoder(
|
||||||
|
tokens=local_encoder_tokens,
|
||||||
|
embeds=local_encoder_embeds,
|
||||||
|
patch_embeds=None,
|
||||||
|
cross_mask=cross_attn_mask_enc,
|
||||||
|
num_patches=patch_lengths.shape[1],
|
||||||
|
patch_ids=patch_ids,
|
||||||
|
)
|
||||||
|
assert h_encoder is not None
|
||||||
|
assert h_cross is not None
|
||||||
|
assert cache_encoder is None
|
||||||
|
expected_shape = (
|
||||||
|
local_encoder_tokens.shape[0],
|
||||||
|
local_encoder_tokens.shape[1],
|
||||||
|
local_encoder.dim,
|
||||||
|
)
|
||||||
|
assert h_encoder.shape == expected_shape
|
||||||
|
assert h_cross.shape == (2, 2048, local_encoder.dim)
|
||||||
|
|
||||||
|
def test_local_decoder_cross_attention(self):
|
||||||
|
args = create_args(cross_attention=True)
|
||||||
|
device = torch.device("cuda")
|
||||||
|
local_decoder = create_local_decoder(args).to(device)
|
||||||
|
|
||||||
|
test_files = {
|
||||||
|
"dec_embeds": "dec_embeds.pt",
|
||||||
|
"decoder_tokens": "local_decoder_tokens.pt",
|
||||||
|
"patch_embeds": "decoder_patch_cross_embeds.pt",
|
||||||
|
}
|
||||||
|
batch = fake_batch()
|
||||||
|
_, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
|
||||||
|
|
||||||
|
tensors = {
|
||||||
|
name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
|
||||||
|
for name, filename in test_files.items()
|
||||||
|
}
|
||||||
|
decoder_patch_ids = decoder_patch_ids_from_lengths(
|
||||||
|
patch_lengths, 0, tensors["decoder_tokens"].shape[-1]
|
||||||
|
)
|
||||||
|
cross_attn_mask_dec = cross_attn_mask(
|
||||||
|
decoder_patch_ids,
|
||||||
|
patch_lengths,
|
||||||
|
tensors["decoder_tokens"].shape[-1],
|
||||||
|
patches_as_queries=False,
|
||||||
|
cross_attn_k=args.cross_attn_k,
|
||||||
|
window=args.cross_attn_window_decoder,
|
||||||
|
block_mask=True,
|
||||||
|
)
|
||||||
|
output, _ = local_decoder(
|
||||||
|
embeds=tensors["dec_embeds"],
|
||||||
|
patch_embeds=tensors["patch_embeds"],
|
||||||
|
tokens=tensors["decoder_tokens"],
|
||||||
|
cross_mask=cross_attn_mask_dec,
|
||||||
|
cache=None,
|
||||||
|
)
|
||||||
|
assert output is not None
|
||||||
|
assert output.shape == (2, tensors["decoder_tokens"].shape[1], args.vocab_size)
|
||||||
|
|
||||||
|
def test_local_decoder(self):
|
||||||
|
args = create_args()
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
local_decoder = create_local_decoder(args).to(device)
|
||||||
|
|
||||||
|
test_files = {
|
||||||
|
"dec_embeds": "dec_embeds.pt",
|
||||||
|
"decoder_tokens": "local_decoder_tokens.pt",
|
||||||
|
"patch_embeds": "decoder_patch_embeds.pt",
|
||||||
|
}
|
||||||
|
|
||||||
|
tensors = {
|
||||||
|
name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
|
||||||
|
for name, filename in test_files.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
output, cache_decoder = local_decoder(
|
||||||
|
embeds=tensors["dec_embeds"],
|
||||||
|
patch_embeds=tensors["patch_embeds"],
|
||||||
|
tokens=tensors["decoder_tokens"],
|
||||||
|
cross_mask=None,
|
||||||
|
cache=None,
|
||||||
|
)
|
||||||
|
assert output is not None
|
||||||
|
expected_shape = (
|
||||||
|
tensors["decoder_tokens"].shape[0],
|
||||||
|
tensors["decoder_tokens"].shape[1],
|
||||||
|
args.vocab_size,
|
||||||
|
)
|
||||||
|
assert output.shape == expected_shape
|
||||||
|
assert cache_decoder is None
|
||||||
|
|
||||||
|
def test_global_transformer(self):
|
||||||
|
args = create_args()
|
||||||
|
device = torch.device("cuda")
|
||||||
|
global_transformer = create_global_transformer(args).to(device)
|
||||||
|
|
||||||
|
test_files = {
|
||||||
|
"global_embeds": "global_embeds.pt",
|
||||||
|
"global_tokens": "global_tokens.pt",
|
||||||
|
}
|
||||||
|
tensors = {
|
||||||
|
name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
|
||||||
|
for name, filename in test_files.items()
|
||||||
|
}
|
||||||
|
h, cache = global_transformer(
|
||||||
|
embeds=tensors["global_embeds"], tokens=tensors["global_tokens"]
|
||||||
|
)
|
||||||
|
h is not None
|
||||||
|
assert h.shape == (2, 256, 512)
|
||||||
|
assert cache is None
|
||||||
|
|
||||||
|
def test_blt_transformer_init(self):
|
||||||
|
args = create_args()
|
||||||
|
model = ByteLatentTransformer(args)
|
||||||
|
assert model is not None
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("attn_type", ["fmha", "sdpa"])
|
||||||
|
def test_blt_transformer_forward(self, attn_type):
|
||||||
|
args = create_args()
|
||||||
|
args = args.model_copy(update=dict(efficient_attn=attn_type))
|
||||||
|
model = ByteLatentTransformer(args)
|
||||||
|
model = model.cuda()
|
||||||
|
batch = fake_batch()
|
||||||
|
x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
|
||||||
|
|
||||||
|
output = model(
|
||||||
|
tokens=x,
|
||||||
|
patch_lengths=patch_lengths,
|
||||||
|
ngram_ids=ngram_ids,
|
||||||
|
)
|
||||||
|
assert output is not None
|
||||||
|
expected_shape = (
|
||||||
|
x.shape[0],
|
||||||
|
x.shape[1],
|
||||||
|
args.vocab_size,
|
||||||
|
)
|
||||||
|
assert output.shape == expected_shape
|
||||||
|
|
||||||
|
def test_blt_transformer_cross_attn_forward(self):
|
||||||
|
args = create_args(cross_attention=True)
|
||||||
|
model = ByteLatentTransformer(args)
|
||||||
|
model = model.cuda()
|
||||||
|
batch = fake_batch()
|
||||||
|
x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
|
||||||
|
|
||||||
|
output = model(
|
||||||
|
tokens=x,
|
||||||
|
patch_lengths=patch_lengths,
|
||||||
|
ngram_ids=ngram_ids,
|
||||||
|
)
|
||||||
|
assert output is not None
|
||||||
|
expected_shape = (
|
||||||
|
x.shape[0],
|
||||||
|
x.shape[1],
|
||||||
|
args.vocab_size,
|
||||||
|
)
|
||||||
|
assert output.shape == expected_shape
|
||||||
|
|
||||||
|
def test_cross_attention_rand(self):
|
||||||
|
x = torch.randn(2, 256, 512, device="cuda")
|
||||||
|
kv = torch.randn(2, 256, 512, device="cuda")
|
||||||
|
cross_attention = CrossAttention(
|
||||||
|
dim=512,
|
||||||
|
head_dim=64,
|
||||||
|
n_heads=8,
|
||||||
|
n_kv_heads=4,
|
||||||
|
norm_eps=1e-6,
|
||||||
|
).to("cuda")
|
||||||
|
mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None)
|
||||||
|
output = cross_attention(x, kv, mask)
|
||||||
|
assert output is not None
|
||||||
|
assert output.shape == (2, 256, 512)
|
||||||
|
|
||||||
|
def test_ngram_embeddings(self):
|
||||||
|
ngram_to_size = {
|
||||||
|
2: 38396,
|
||||||
|
3: 50000,
|
||||||
|
4: 50000,
|
||||||
|
5: 50000,
|
||||||
|
6: 50000,
|
||||||
|
7: 50000,
|
||||||
|
8: 50000,
|
||||||
|
}
|
||||||
|
batch = fake_batch()
|
||||||
|
ngram_processor = NgramProcessor(BLT_DATA, ngram_to_size)
|
||||||
|
ngram_ids = ngram_processor.encode_token_ngrams(batch.x)
|
||||||
|
ngram_ids = np.stack(ngram_ids, axis=0)
|
||||||
|
batch = replace(batch, ngram_ids=ngram_ids)
|
||||||
|
args = create_args(cross_attention=True)
|
||||||
|
args = args.model_copy(
|
||||||
|
update=dict(
|
||||||
|
encoder_ngram_to_size_str="2:38396,3:50000,4:50000,5:50000,6:50000,7:50000,8:50000",
|
||||||
|
encoder_enable_byte_ngrams=True,
|
||||||
|
ngram_vocab_sizes=ngram_processor.ngram_vocab_sizes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model = ByteLatentTransformer(args)
|
||||||
|
model = model.cuda()
|
||||||
|
x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
|
||||||
|
|
||||||
|
output = model(
|
||||||
|
tokens=x,
|
||||||
|
patch_lengths=patch_lengths,
|
||||||
|
ngram_ids=ngram_ids,
|
||||||
|
)
|
||||||
|
assert output is not None
|
||||||
|
expected_shape = (
|
||||||
|
x.shape[0],
|
||||||
|
x.shape[1],
|
||||||
|
args.vocab_size,
|
||||||
|
)
|
||||||
|
assert output.shape == expected_shape
|
||||||
|
|
||||||
|
def test_loss_backward(self):
|
||||||
|
args = create_args()
|
||||||
|
args = args.model_copy(update=dict(efficient_attn="sdpa"))
|
||||||
|
batch = fake_batch()
|
||||||
|
model = ByteLatentTransformer(args)
|
||||||
|
steps = 10
|
||||||
|
optimizer, scheduler = build_optimizer(model, OptimArgs(lr=4e-04), steps)
|
||||||
|
model = model.cuda()
|
||||||
|
x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
|
||||||
|
|
||||||
|
initial_loss = None
|
||||||
|
final_loss = None
|
||||||
|
for step in range(steps):
|
||||||
|
output = model(
|
||||||
|
tokens=x,
|
||||||
|
patch_lengths=patch_lengths,
|
||||||
|
ngram_ids=ngram_ids,
|
||||||
|
)
|
||||||
|
loss, _ = compute_loss(output, y, mask, 1.0)
|
||||||
|
if step == 0:
|
||||||
|
initial_loss = loss.item()
|
||||||
|
if step == steps - 1:
|
||||||
|
final_loss = loss.item()
|
||||||
|
prev_loss = loss.item()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
assert (
|
||||||
|
final_loss < initial_loss
|
||||||
|
), f"Training did not reduce loss: initial {initial_loss}, final {final_loss}"
|
55
bytelatent/test_entropy_model.py
Normal file
55
bytelatent/test_entropy_model.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from bytelatent.constants import BLT_DATA
|
||||||
|
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
|
||||||
|
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
|
||||||
|
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum, entropy
|
||||||
|
from bytelatent.entropy_model import load_entropy_model
|
||||||
|
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
||||||
|
|
||||||
|
ENTROPY_MODEL = "transformer_100m"
|
||||||
|
ARROW_TEST_DATA = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
|
||||||
|
|
||||||
|
|
||||||
|
def test_entropy_model():
|
||||||
|
initial_state = ArrowFileIteratorState(
|
||||||
|
file_path=None,
|
||||||
|
num_workers=1,
|
||||||
|
worker_id=0,
|
||||||
|
preprocess_dir=None,
|
||||||
|
entropy_model_name=ENTROPY_MODEL,
|
||||||
|
dataset_files=[ARROW_TEST_DATA],
|
||||||
|
row_num=0,
|
||||||
|
arrow_batch_size=100,
|
||||||
|
)
|
||||||
|
arrow_file = initial_state.build()
|
||||||
|
tokenizer_args = TokenizerArgs(
|
||||||
|
name="blt",
|
||||||
|
init_kwargs={
|
||||||
|
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
entropy_model = load_entropy_model(
|
||||||
|
BLT_DATA / "checkpoint_0100000_consolidated",
|
||||||
|
os.path.join(
|
||||||
|
BLT_DATA,
|
||||||
|
"entropy_model.pth",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
preprocess_iter = PreprocessIterator(
|
||||||
|
arrow_file,
|
||||||
|
tokenizer_args=tokenizer_args,
|
||||||
|
patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.entropy),
|
||||||
|
add_patches=False,
|
||||||
|
)
|
||||||
|
for example in preprocess_iter.create_iter():
|
||||||
|
tokens = torch.tensor(example.tokens).unsqueeze(0)
|
||||||
|
expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
|
||||||
|
preds = entropy_model(tokens)
|
||||||
|
pred_entropies = entropy(preds)
|
||||||
|
assert pred_entropies.shape == expected_entropies.shape
|
||||||
|
assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5)
|
||||||
|
break
|
1
bytelatent/tokenizers/__init__.py
Normal file
1
bytelatent/tokenizers/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
19
bytelatent/tokenizers/abstract_tokenizer.py
Normal file
19
bytelatent/tokenizers/abstract_tokenizer.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import abc
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer(abc.ABC):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def encode(self, text: str, add_bos: bool, add_eos: bool):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def decode(self, tokens: list[int]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_token_offsets(
|
||||||
|
self, text: str, tokens: list[int] | None = None
|
||||||
|
) -> tuple[list[str], list[int]]:
|
||||||
|
"""Return the offsets of the tokens in the original text. Only used for evaluation."""
|
||||||
|
pass
|
150
bytelatent/tokenizers/blt_tokenizer.py
Normal file
150
bytelatent/tokenizers/blt_tokenizer.py
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import re
|
||||||
|
|
||||||
|
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
|
||||||
|
from bytelatent.tokenizers.constants import (
|
||||||
|
BOE_ID,
|
||||||
|
BOS_ID,
|
||||||
|
BPE_ID,
|
||||||
|
BYTE_UNITS,
|
||||||
|
EOS_ID,
|
||||||
|
OFFSET,
|
||||||
|
PAD_ID,
|
||||||
|
)
|
||||||
|
from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_bytes(s):
|
||||||
|
# check if the output is a bytes like object of the format <0x00>
|
||||||
|
if re.match(r"<0x[0-9a-fA-F]+>", s):
|
||||||
|
return bytes.fromhex(s[3:-1])
|
||||||
|
else:
|
||||||
|
return bytes(s, "utf-8", errors="ignore")
|
||||||
|
|
||||||
|
|
||||||
|
def text2bytes_bpe_delims(
|
||||||
|
text: str,
|
||||||
|
*,
|
||||||
|
bpe_tokenizer,
|
||||||
|
bpe_id: int,
|
||||||
|
offsetting_special_char: int,
|
||||||
|
add_bos: bool,
|
||||||
|
add_eos: bool,
|
||||||
|
):
|
||||||
|
cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos)
|
||||||
|
# merge the leading space tokens
|
||||||
|
leading_space_tokens = []
|
||||||
|
other_bpe_tokens = []
|
||||||
|
leading = True
|
||||||
|
for token in cur_bpe:
|
||||||
|
bpe_str = bpe_tokenizer.sp_model.id_to_piece(token)
|
||||||
|
if leading and all(c == "▁" for c in bpe_str):
|
||||||
|
leading_space_tokens.append(bpe_str)
|
||||||
|
else:
|
||||||
|
leading = False
|
||||||
|
other_bpe_tokens.append(bpe_str)
|
||||||
|
cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens
|
||||||
|
|
||||||
|
# Remove the '▁' characters
|
||||||
|
bpe_strs = []
|
||||||
|
for i, bpe_str in enumerate(cur_bpe_strs):
|
||||||
|
if (
|
||||||
|
len(bpe_strs) <= 1
|
||||||
|
and all([c == " " for s in bpe_strs for c in s])
|
||||||
|
and not all(c == "▁" for c in bpe_str)
|
||||||
|
):
|
||||||
|
# Remove leading space for first non space token.
|
||||||
|
bpe_str = bpe_str.replace("▁", "")
|
||||||
|
elif i == 0 and all(c == "▁" for c in bpe_str):
|
||||||
|
bpe_str = " " * (len(text) - len(text.lstrip(" ")))
|
||||||
|
else:
|
||||||
|
bpe_str = bpe_str.replace("▁", " ")
|
||||||
|
if len(bpe_str) > 0:
|
||||||
|
bpe_strs.append(bpe_str)
|
||||||
|
ex_seq = []
|
||||||
|
# Convert bpe tokens to bytes
|
||||||
|
for s in bpe_strs:
|
||||||
|
byte_chunk = convert_to_bytes(s)
|
||||||
|
proc_chunk = [int(unit) for unit in byte_chunk]
|
||||||
|
ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk)
|
||||||
|
|
||||||
|
return ex_seq
|
||||||
|
|
||||||
|
|
||||||
|
class BltTokenizer(Tokenizer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vocab_size_unit_1: int = BYTE_UNITS,
|
||||||
|
bpe_delim: bool = False,
|
||||||
|
bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
|
||||||
|
add_bos: bool = True,
|
||||||
|
add_eos: bool = True,
|
||||||
|
):
|
||||||
|
self.add_bos = add_bos
|
||||||
|
self.add_eos = add_eos
|
||||||
|
self.vocab_size_unit_1 = vocab_size_unit_1
|
||||||
|
self.boe_id = BOE_ID
|
||||||
|
self.bos_id = BOS_ID
|
||||||
|
self.eos_id = EOS_ID
|
||||||
|
self.pad_id = PAD_ID
|
||||||
|
self.bpe_id = BPE_ID
|
||||||
|
self.bpe_tokenizer_path = bpe_tokenizer_path
|
||||||
|
if bpe_delim:
|
||||||
|
self.bpe_tokenizer = SentencePieceTokenizer(
|
||||||
|
model_path=self.bpe_tokenizer_path
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.bpe_tokenizer = None
|
||||||
|
self.bpe_delim = bpe_delim
|
||||||
|
self.offsetting_special_char = OFFSET
|
||||||
|
self.vocab_size_unit_1 = vocab_size_unit_1
|
||||||
|
self.n_words = vocab_size_unit_1 + self.offsetting_special_char
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
|
||||||
|
):
|
||||||
|
if add_bos is None:
|
||||||
|
add_bos = self.add_bos
|
||||||
|
if add_eos is None:
|
||||||
|
add_eos = self.add_eos
|
||||||
|
|
||||||
|
if self.bpe_delim:
|
||||||
|
tokens = text2bytes_bpe_delims(
|
||||||
|
text,
|
||||||
|
bpe_tokenizer=self.bpe_tokenizer,
|
||||||
|
bpe_id=self.bpe_id,
|
||||||
|
offsetting_special_char=self.offsetting_special_char,
|
||||||
|
add_bos=False,
|
||||||
|
add_eos=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tokens = bytes(text, encoding="utf-8", errors="ignore")
|
||||||
|
|
||||||
|
# Offsetting
|
||||||
|
tokens = [int(unit) + self.offsetting_special_char for unit in tokens]
|
||||||
|
|
||||||
|
if add_bos:
|
||||||
|
tokens.insert(0, self.bos_id)
|
||||||
|
if add_eos:
|
||||||
|
tokens.append(self.eos_id)
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def decode(self, tokens: list[int], cut_at_eos: bool = False):
|
||||||
|
if cut_at_eos:
|
||||||
|
for k, t in enumerate(tokens):
|
||||||
|
if t == self.eos_id:
|
||||||
|
tokens = tokens[: k + 1]
|
||||||
|
break
|
||||||
|
return bytes(
|
||||||
|
[
|
||||||
|
tok - self.offsetting_special_char
|
||||||
|
for tok in tokens
|
||||||
|
if tok - self.offsetting_special_char >= 0
|
||||||
|
]
|
||||||
|
).decode("utf-8", errors="ignore")
|
||||||
|
|
||||||
|
def get_token_offsets(self, text: str, tokens: list[int] | None = None):
|
||||||
|
# TODO: Figure out what this does
|
||||||
|
raise NotImplementedError()
|
69
bytelatent/tokenizers/build_tokenizer.py
Normal file
69
bytelatent/tokenizers/build_tokenizer.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
|
||||||
|
from bytelatent.tokenizers.byte_tokenizer import ByteTokenizer
|
||||||
|
from bytelatent.tokenizers.tiktoken_tokenizer import TikTokenTokenizer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
|
has_sp = True
|
||||||
|
except ImportError:
|
||||||
|
has_sp = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
from tiktoken.load import load_tiktoken_bpe
|
||||||
|
|
||||||
|
has_tiktoken = True
|
||||||
|
except ImportError:
|
||||||
|
has_tiktoken = False
|
||||||
|
|
||||||
|
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
|
||||||
|
from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MockTokenizer(Tokenizer):
|
||||||
|
n_words: int = 256
|
||||||
|
|
||||||
|
def encode(self, text: str, add_bos: bool, add_eos: bool):
|
||||||
|
return text
|
||||||
|
|
||||||
|
def decode(self, tokens):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_token_offsets(
|
||||||
|
self, text: str, tokens: list[int] | None = None
|
||||||
|
) -> tuple[list[str]]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizerArgs(BaseModel):
|
||||||
|
name: str = "bytes"
|
||||||
|
init_kwargs: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
def build(self) -> Tokenizer:
|
||||||
|
if self.init_kwargs is None:
|
||||||
|
init_kwargs = {}
|
||||||
|
else:
|
||||||
|
init_kwargs = self.init_kwargs
|
||||||
|
if self.name == "blt":
|
||||||
|
return BltTokenizer(**init_kwargs)
|
||||||
|
elif self.name == "bytes":
|
||||||
|
return ByteTokenizer(**init_kwargs)
|
||||||
|
elif self.name == "mock":
|
||||||
|
return MockTokenizer(**init_kwargs)
|
||||||
|
elif self.name == "sp":
|
||||||
|
assert has_sp, "sentencepiece not installed"
|
||||||
|
return SentencePieceTokenizer(**init_kwargs)
|
||||||
|
elif self.name == "tiktoken":
|
||||||
|
assert has_tiktoken, "tiktoken not installed"
|
||||||
|
return TikTokenTokenizer(**init_kwargs)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"{self.name} tokenizer type is not implemented")
|
35
bytelatent/tokenizers/byte_tokenizer.py
Normal file
35
bytelatent/tokenizers/byte_tokenizer.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class ByteTokenizer(Tokenizer):
|
||||||
|
def __init__(self):
|
||||||
|
self.bos_id = 256
|
||||||
|
self.eos_id = 257
|
||||||
|
self.n_words = 258
|
||||||
|
|
||||||
|
def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
|
||||||
|
tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def decode(self, tokens: list[int]):
|
||||||
|
byte_tokens = bytes([t for t in tokens if t < 256])
|
||||||
|
return byte_tokens.decode("utf-8", errors="backslashreplace")
|
||||||
|
|
||||||
|
def get_token_offsets(
|
||||||
|
self, text: str, tokens: list[int] | None = None
|
||||||
|
) -> tuple[list[str], list[int]]:
|
||||||
|
if tokens is None:
|
||||||
|
tokens = self.encode(text)
|
||||||
|
|
||||||
|
decoded_chars, offsets = [], []
|
||||||
|
byte_pos = 0
|
||||||
|
for token in tokens:
|
||||||
|
if token < 256:
|
||||||
|
char = bytes([token]).decode("utf-8", errors="ignore")
|
||||||
|
if char:
|
||||||
|
decoded_chars.append(char)
|
||||||
|
offsets.append(byte_pos)
|
||||||
|
byte_pos += len(char.encode("utf-8"))
|
||||||
|
|
||||||
|
return decoded_chars, offsets
|
12
bytelatent/tokenizers/constants.py
Normal file
12
bytelatent/tokenizers/constants.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
|
||||||
|
SEP = " "
|
||||||
|
BOS_ID: int = 1
|
||||||
|
EOS_ID: int = 2
|
||||||
|
PAD_ID: int = -1
|
||||||
|
BOE_ID: int = 0
|
||||||
|
BPE_ID: int = 3
|
||||||
|
OFFSET: int = 4
|
||||||
|
|
||||||
|
BYTE_UNITS: int = 256
|
59
bytelatent/tokenizers/sentence_piece_tokenizer.py
Normal file
59
bytelatent/tokenizers/sentence_piece_tokenizer.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
|
has_sp = True
|
||||||
|
except ImportError:
|
||||||
|
has_sp = False
|
||||||
|
|
||||||
|
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SentencePieceTokenizer(Tokenizer):
|
||||||
|
def __init__(
|
||||||
|
self, model_path: str, add_bos: bool = True, add_eos: bool = True
|
||||||
|
) -> None:
|
||||||
|
assert os.path.isfile(model_path), model_path
|
||||||
|
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
||||||
|
|
||||||
|
logger.info(f"Reloaded SentencePiece model from {model_path}")
|
||||||
|
|
||||||
|
# BOS / EOS token IDs
|
||||||
|
self.n_words: int = self.sp_model.vocab_size()
|
||||||
|
self.bos_id: int = self.sp_model.bos_id()
|
||||||
|
self.eos_id: int = self.sp_model.eos_id()
|
||||||
|
self.pad_id: int = self.sp_model.pad_id()
|
||||||
|
self.add_bos = add_bos
|
||||||
|
self.add_eos = add_eos
|
||||||
|
logger.info(
|
||||||
|
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
|
||||||
|
)
|
||||||
|
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
||||||
|
|
||||||
|
def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
|
||||||
|
if add_bos is None:
|
||||||
|
add_bos = self.add_bos
|
||||||
|
|
||||||
|
if add_eos is None:
|
||||||
|
add_eos = self.add_eos
|
||||||
|
assert type(s) is str
|
||||||
|
tokens = (
|
||||||
|
[self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos
|
||||||
|
)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def decode(self, tokens: list[int]):
|
||||||
|
return self.sp_model.decode(tokens)
|
||||||
|
|
||||||
|
def get_token_offsets(
|
||||||
|
self, text: str, tokens: list[int] | None = None
|
||||||
|
) -> tuple[list[str], list[int]]:
|
||||||
|
pieces = self.sp_model.encode_as_immutable_proto(text).pieces
|
||||||
|
substrs = [p.surface for p in pieces]
|
||||||
|
offsets = [p.begin for p in pieces]
|
||||||
|
return substrs, offsets
|
41
bytelatent/tokenizers/test_blt_tokenizer.py
Normal file
41
bytelatent/tokenizers/test_blt_tokenizer.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import json
|
||||||
|
|
||||||
|
from bytelatent.constants import BLT_DATA
|
||||||
|
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
|
||||||
|
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
||||||
|
|
||||||
|
|
||||||
|
def test_tokenizer_bytes():
|
||||||
|
with open("fixtures/tokenizer_data.json") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
examples: list[str] = data["texts"]
|
||||||
|
examples_tokens: list[list[int]] = data["tokens"]
|
||||||
|
|
||||||
|
tokenizer = BltTokenizer(bpe_delim=False)
|
||||||
|
for i in range(len(examples)):
|
||||||
|
assert tokenizer.encode(examples[i]) == examples_tokens[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tokenizer_bpe():
|
||||||
|
with open("fixtures/tokenizer_data_bpe_delim.json") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
examples: list[str] = data["texts"]
|
||||||
|
examples_tokens: list[list[int]] = data["tokens"]
|
||||||
|
|
||||||
|
tokenizer = BltTokenizer(bpe_delim=True)
|
||||||
|
for i in range(len(examples)):
|
||||||
|
assert tokenizer.encode(examples[i]) == examples_tokens[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_tokenizer_from_args():
|
||||||
|
tokenizer_args = TokenizerArgs(
|
||||||
|
name="blt",
|
||||||
|
init_kwargs={
|
||||||
|
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
tokenizer = tokenizer_args.build()
|
||||||
|
assert tokenizer.encode("test text") is not None
|
86
bytelatent/tokenizers/tiktoken_tokenizer.py
Normal file
86
bytelatent/tokenizers/tiktoken_tokenizer.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import logging
|
||||||
|
from copy import copy
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
from tiktoken.load import load_tiktoken_bpe
|
||||||
|
|
||||||
|
has_tiktoken = True
|
||||||
|
except ImportError:
|
||||||
|
has_tiktoken = False
|
||||||
|
DEFAULT_TIKTOKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
||||||
|
DEFAULT_TIKTOKEN_SPECIAL_TOKENS = {
|
||||||
|
"<|begin_of_text|>": 0,
|
||||||
|
"<|end_of_text|>": 1,
|
||||||
|
"<|fim_prefix|>": 2,
|
||||||
|
"<|fim_middle|>": 3,
|
||||||
|
"<|fim_end_fill|>": 253,
|
||||||
|
"<|fim_pad|>": 254,
|
||||||
|
"<|fim_suffix|>": 255,
|
||||||
|
}
|
||||||
|
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TikTokenTokenizer(Tokenizer):
|
||||||
|
def __init__(self, model_path: str) -> None:
|
||||||
|
mergeable_ranks = load_tiktoken_bpe(model_path)
|
||||||
|
all_special_tokens_with_ids = copy(DEFAULT_TIKTOKEN_SPECIAL_TOKENS)
|
||||||
|
missing_ids = set(range(256)) - set(all_special_tokens_with_ids.values())
|
||||||
|
for id in missing_ids:
|
||||||
|
all_special_tokens_with_ids[f"<|reserved_special_token_{id}|>"] = id
|
||||||
|
for name in all_special_tokens_with_ids:
|
||||||
|
all_special_tokens_with_ids[name] += len(mergeable_ranks)
|
||||||
|
|
||||||
|
self.tkt_model = tiktoken.core.Encoding(
|
||||||
|
name=Path(model_path).stem,
|
||||||
|
pat_str=DEFAULT_TIKTOKEN_PATTERN,
|
||||||
|
mergeable_ranks=mergeable_ranks,
|
||||||
|
special_tokens=all_special_tokens_with_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>")
|
||||||
|
self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>")
|
||||||
|
|
||||||
|
self.n_words: int = self.tkt_model.n_vocab
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode(self, s: str, add_bos: bool, add_eos: bool):
|
||||||
|
assert isinstance(s, str)
|
||||||
|
|
||||||
|
subs = []
|
||||||
|
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
|
||||||
|
subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
|
||||||
|
return (
|
||||||
|
[self.bos_id] * add_bos
|
||||||
|
+ sum(self.tkt_model.encode_ordinary_batch(subs), start=[])
|
||||||
|
+ [self.eos_id] * add_eos
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, tokens: list[int]):
|
||||||
|
return self.tkt_model.decode(tokens)
|
||||||
|
|
||||||
|
def get_token_offsets(
|
||||||
|
self, text: str, tokens: list[int] | None = None
|
||||||
|
) -> tuple[list[str], list[int]]:
|
||||||
|
if tokens is not None:
|
||||||
|
token_bytes = self.tkt_model.decode_tokens_bytes(tokens)
|
||||||
|
else:
|
||||||
|
token_bytes = self.tkt_model.decode_tokens_bytes(
|
||||||
|
self.tkt_model.encode(text, allowed_special="all")
|
||||||
|
)
|
||||||
|
|
||||||
|
text_len, offsets = 0, []
|
||||||
|
for token in token_bytes:
|
||||||
|
offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
|
||||||
|
text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
|
||||||
|
substrs = [text[s:e] for s, e in zip(offsets, offsets[1:] + [None])]
|
||||||
|
return substrs, offsets
|
651
bytelatent/train.py
Normal file
651
bytelatent/train.py
Normal file
|
@ -0,0 +1,651 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
from typing import Any, Dict, Type, TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
import torch.nn.functional
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import wandb
|
||||||
|
import xformers.profiler
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from torch.distributed._tensor import DTensor
|
||||||
|
from torch.distributed.checkpoint.stateful import Stateful
|
||||||
|
from torch.optim import lr_scheduler
|
||||||
|
|
||||||
|
from bytelatent.args import TrainArgs
|
||||||
|
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
|
||||||
|
from bytelatent.data.data_types import DataLoaderState
|
||||||
|
from bytelatent.distributed import (
|
||||||
|
check_model_value_range,
|
||||||
|
clean_env,
|
||||||
|
dist_mean_dict,
|
||||||
|
get_device_mesh,
|
||||||
|
get_is_master,
|
||||||
|
get_world_size,
|
||||||
|
init_signal_handler,
|
||||||
|
parallelize_model,
|
||||||
|
requeue_slurm_job,
|
||||||
|
setup_env,
|
||||||
|
setup_torch_distributed,
|
||||||
|
)
|
||||||
|
from bytelatent.logger import init_logger
|
||||||
|
from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
|
||||||
|
from bytelatent.model.blt import ByteLatentTransformer
|
||||||
|
from bytelatent.optim import build_optimizer
|
||||||
|
from bytelatent.probe import AutoProbeD
|
||||||
|
from bytelatent.profiling import maybe_run_profiler
|
||||||
|
from bytelatent.stool import StoolArgs, launch_job
|
||||||
|
from bytelatent.transformer import (
|
||||||
|
build_fsdp_grouping_plan,
|
||||||
|
get_no_recompute_ops,
|
||||||
|
get_num_flop_per_token,
|
||||||
|
tp_parallelize,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_dict(d, parent_key="", sep="_"):
|
||||||
|
items = []
|
||||||
|
for k, v in d.items():
|
||||||
|
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
||||||
|
if isinstance(v, dict):
|
||||||
|
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||||
|
else:
|
||||||
|
items.append((new_key, v))
|
||||||
|
return dict(items)
|
||||||
|
|
||||||
|
|
||||||
|
def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
|
||||||
|
"""
|
||||||
|
Converts a dictionary to a dataclass instance, recursively for nested structures.
|
||||||
|
"""
|
||||||
|
base = OmegaConf.structured(cls())
|
||||||
|
OmegaConf.set_struct(base, strict)
|
||||||
|
override = OmegaConf.create(data)
|
||||||
|
return OmegaConf.to_object(OmegaConf.merge(base, override))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainState(Stateful):
|
||||||
|
step: int # Nb of steps taken by the optimizer
|
||||||
|
acc_step: int # Nb of accumulation steps done since last optimizer step
|
||||||
|
scheduler: lr_scheduler.LambdaLR
|
||||||
|
data_loader_state: DataLoaderState
|
||||||
|
scale: float = 1.0
|
||||||
|
|
||||||
|
def state_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"step": self.step,
|
||||||
|
"acc_step": self.acc_step,
|
||||||
|
"data_loader_state": self.data_loader_state.dict(),
|
||||||
|
"scheduler": self.scheduler.state_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
self.step = state_dict["step"]
|
||||||
|
self.acc_step = state_dict["acc_step"]
|
||||||
|
self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"])
|
||||||
|
self.scheduler.load_state_dict(state_dict["scheduler"])
|
||||||
|
|
||||||
|
|
||||||
|
def validate_train_args(args: TrainArgs, output_size: int):
|
||||||
|
if args.model.vocab_size < 0:
|
||||||
|
logger.info(f"Setting model output size to {args.model.vocab_size}")
|
||||||
|
args.model.vocab_size = output_size
|
||||||
|
|
||||||
|
assert args.dump_dir, "Dump dir not set"
|
||||||
|
|
||||||
|
if args.checkpoint.path is None:
|
||||||
|
logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
|
||||||
|
args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
|
||||||
|
|
||||||
|
for source in args.data.sources:
|
||||||
|
data_path = os.path.join(args.data.root_dir, source)
|
||||||
|
assert os.path.exists(data_path), f"{data_path} doesn't exist"
|
||||||
|
|
||||||
|
if (
|
||||||
|
args.distributed.dp_replicate
|
||||||
|
* args.distributed.dp_shard
|
||||||
|
* args.distributed.tp_size
|
||||||
|
!= get_world_size()
|
||||||
|
):
|
||||||
|
assert get_world_size() % args.distributed.dp_shard == 0
|
||||||
|
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
|
||||||
|
|
||||||
|
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
|
||||||
|
args.distributed.dp_replicate = (
|
||||||
|
args.distributed.dp_replicate // args.distributed.tp_size
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
args.distributed.dp_replicate
|
||||||
|
* args.distributed.dp_shard
|
||||||
|
* args.distributed.tp_size
|
||||||
|
== get_world_size()
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.distributed.fsdp_type == "no_shard":
|
||||||
|
assert (
|
||||||
|
args.distributed.dp_shard == 1
|
||||||
|
and args.distributed.dp_replicate == get_world_size()
|
||||||
|
)
|
||||||
|
|
||||||
|
args.model.max_seqlen = args.data.seq_len
|
||||||
|
|
||||||
|
if args.distributed.tp_size == 1:
|
||||||
|
logger.warning(
|
||||||
|
"Tensor parallelism has not been tested for a while, use at your own risk"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
args.probe_freq != args.profiling.mem_steps
|
||||||
|
), "Don't profile during probe step"
|
||||||
|
assert (
|
||||||
|
args.probe_freq != args.profiling.profile_steps
|
||||||
|
), "Don't profile during probe step"
|
||||||
|
if args.logging.wandb is not None:
|
||||||
|
args.logging.wandb.name = args.name
|
||||||
|
|
||||||
|
if args.probe_freq is not None:
|
||||||
|
assert (
|
||||||
|
args.distributed.tp_size == 1
|
||||||
|
), "Probing not supported with tensor parallelism"
|
||||||
|
assert (
|
||||||
|
args.distributed.selective_activation_checkpointing is False
|
||||||
|
), "Probing not supported with selective activation checkpointing"
|
||||||
|
|
||||||
|
|
||||||
|
preemption_flag = dict(flag=False)
|
||||||
|
|
||||||
|
|
||||||
|
def set_preemption_flag(signum, frame):
|
||||||
|
logger.warning("Signal handler called with signal " + str(signum))
|
||||||
|
logger.warning("Preemption ! checkpointing asap and exiting.")
|
||||||
|
preemption_flag["flag"] = True
|
||||||
|
|
||||||
|
|
||||||
|
def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
|
||||||
|
test = train_state.step % freq == 0
|
||||||
|
if acc_step is not None:
|
||||||
|
test = test and (train_state.acc_step == acc_step)
|
||||||
|
elif acc_freq is not None:
|
||||||
|
test = test and ((train_state.acc_step % acc_freq) == 0)
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def compute_loss(p, y, mask, scale):
|
||||||
|
tok_loss = scale * F.cross_entropy(
|
||||||
|
p.flatten(0, 1), y.flatten(0, 1), reduction="none"
|
||||||
|
)
|
||||||
|
if mask is None:
|
||||||
|
loss = tok_loss.mean()
|
||||||
|
else:
|
||||||
|
mask = mask.flatten(0, 1)
|
||||||
|
tok_loss = tok_loss * mask
|
||||||
|
loss = tok_loss.sum() / (mask.sum() + 1e-6)
|
||||||
|
return loss, tok_loss
|
||||||
|
|
||||||
|
|
||||||
|
def train(args: TrainArgs):
|
||||||
|
with ExitStack() as context_stack:
|
||||||
|
tokenizer = args.data.tokenizer_args.build()
|
||||||
|
validate_train_args(
|
||||||
|
args,
|
||||||
|
tokenizer.n_words,
|
||||||
|
)
|
||||||
|
if get_is_master():
|
||||||
|
os.makedirs(args.dump_dir, exist_ok=True)
|
||||||
|
args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml")
|
||||||
|
init_logger(Path(args.dump_dir) / "train.log")
|
||||||
|
init_signal_handler(set_preemption_flag) # For handling preemption signals.
|
||||||
|
setup_env(args.env)
|
||||||
|
setup_torch_distributed(args.distributed)
|
||||||
|
world_mesh = get_device_mesh(args.distributed)
|
||||||
|
logger.info(f"Starting job: {args.name}")
|
||||||
|
|
||||||
|
# build dataloader
|
||||||
|
# need dp world size and rank
|
||||||
|
dp_mesh = world_mesh["dp_replicate"]
|
||||||
|
dp_degree = dp_mesh.size()
|
||||||
|
dp_rank = dp_mesh.get_local_rank()
|
||||||
|
if args.distributed.dp_shard > 1:
|
||||||
|
dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
|
||||||
|
dp_degree *= world_mesh["dp_shard"].size()
|
||||||
|
|
||||||
|
logger.info(f"Running on dp rank : {dp_rank}")
|
||||||
|
logger.info(f"Running on dp size : {dp_degree}")
|
||||||
|
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
logger.info("Building model")
|
||||||
|
|
||||||
|
# Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
|
||||||
|
with torch.device("meta"):
|
||||||
|
model = ByteLatentTransformer(args.model)
|
||||||
|
logger.info("Model is built !")
|
||||||
|
|
||||||
|
model_param_count = get_num_params(model)
|
||||||
|
|
||||||
|
model = parallelize_model(
|
||||||
|
model,
|
||||||
|
world_mesh,
|
||||||
|
args.model,
|
||||||
|
args.distributed,
|
||||||
|
fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
|
||||||
|
tp_parallelize=tp_parallelize,
|
||||||
|
no_recompute_ops=get_no_recompute_ops(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Once we shard the model on different gpus we can actually initialize the model
|
||||||
|
# First we create empty tensors of the correct shapes
|
||||||
|
model = model.to_empty(device="cuda")
|
||||||
|
# Then we init the model. Please make sure this function initializes *ALL* parameters
|
||||||
|
# and buffers, otherwise you will have random values in the unitialized tensors
|
||||||
|
# which will silently fail (give nan gradients for example)
|
||||||
|
|
||||||
|
if args.checkpoint.init_ckpt_path:
|
||||||
|
logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
|
||||||
|
load_from_checkpoint(
|
||||||
|
args.checkpoint.init_ckpt_path, model, model_key="model"
|
||||||
|
) # Put model_key="" if its directly the model checkpoint
|
||||||
|
model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded
|
||||||
|
else:
|
||||||
|
with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
|
||||||
|
torch.manual_seed(args.model.seed)
|
||||||
|
model.init_weights()
|
||||||
|
check_model_value_range(model, range=10.0, std=1.0)
|
||||||
|
|
||||||
|
# log model size
|
||||||
|
|
||||||
|
logger.info(f"Model size: {model_param_count:,} total parameters")
|
||||||
|
|
||||||
|
gpu_memory_monitor = GPUMemoryMonitor("cuda")
|
||||||
|
logger.info(
|
||||||
|
f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
|
||||||
|
f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
|
||||||
|
)
|
||||||
|
logger.info(f"GPU memory usage: {gpu_memory_monitor}")
|
||||||
|
|
||||||
|
# build optimizer after apply parallelisms to the model
|
||||||
|
optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
|
||||||
|
data_loader = args.data.build_from_rank(dp_rank, dp_degree)
|
||||||
|
data_loader_state = data_loader.get_state()
|
||||||
|
|
||||||
|
train_state = TrainState(
|
||||||
|
step=0,
|
||||||
|
acc_step=0,
|
||||||
|
data_loader_state=data_loader_state,
|
||||||
|
scheduler=scheduler,
|
||||||
|
scale=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
|
||||||
|
checkpoint.load(model, optimizer, train_state, world_mesh)
|
||||||
|
# Either load from latest checkpoint or start from scratch
|
||||||
|
if args.probe_freq is not None:
|
||||||
|
if get_is_master():
|
||||||
|
os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
probe = AutoProbeD(
|
||||||
|
model,
|
||||||
|
(
|
||||||
|
Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
|
||||||
|
if (dp_rank % 128 == 0)
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
probe_mod = model._orig_mod if args.distributed.compile else model
|
||||||
|
|
||||||
|
gc.disable()
|
||||||
|
|
||||||
|
# train loop
|
||||||
|
model.train()
|
||||||
|
metric_logger = context_stack.enter_context(
|
||||||
|
MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
|
||||||
|
)
|
||||||
|
data_loader = train_state.data_loader_state.build()
|
||||||
|
batch_iterator = data_loader.create_iter()
|
||||||
|
|
||||||
|
torch_profiler = context_stack.enter_context(
|
||||||
|
maybe_run_profiler(args.dump_dir, model, args.profiling)
|
||||||
|
)
|
||||||
|
|
||||||
|
nwords_since_last_log = 0
|
||||||
|
time_last_log = timer()
|
||||||
|
gc.collect()
|
||||||
|
while train_state.step < args.steps:
|
||||||
|
# We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
|
||||||
|
train_state.acc_step += 1
|
||||||
|
train_state.acc_step = train_state.acc_step % args.grad_acc_steps
|
||||||
|
|
||||||
|
# get batch
|
||||||
|
curr_lr = float(optimizer.param_groups[0]["lr"])
|
||||||
|
data_load_start = timer()
|
||||||
|
batch = next(batch_iterator)
|
||||||
|
batch_x = torch.from_numpy(
|
||||||
|
batch.x,
|
||||||
|
).cuda()
|
||||||
|
batch_y = torch.from_numpy(batch.y).cuda()
|
||||||
|
batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
|
||||||
|
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
|
||||||
|
|
||||||
|
if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot enable byte ngrams and have batch.ngram_ids be None"
|
||||||
|
)
|
||||||
|
ngram_ids = (
|
||||||
|
None
|
||||||
|
if batch.ngram_ids is None
|
||||||
|
else torch.from_numpy(batch.ngram_ids).cuda()
|
||||||
|
)
|
||||||
|
|
||||||
|
if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
|
||||||
|
logger.info("garbage collection")
|
||||||
|
# we do garbage collection manually otherwise different processes
|
||||||
|
# run the GC at different times so they slow down the whole pipeline
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
data_load_time = round(timer() - data_load_start, 4)
|
||||||
|
nwords_since_last_log += batch_x.numel()
|
||||||
|
|
||||||
|
bsz, seqlen = batch_y.shape
|
||||||
|
|
||||||
|
# forward
|
||||||
|
start_timer = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_timer = torch.cuda.Event(enable_timing=True)
|
||||||
|
start_timer.record()
|
||||||
|
|
||||||
|
# This is an automatic probe that will compute statistics
|
||||||
|
# of all linears' inputs, weights and outputs
|
||||||
|
# along with attention logits and entropy
|
||||||
|
# both in forward and backward pass
|
||||||
|
tok_loss = None
|
||||||
|
if (args.probe_freq is not None) and every_n_steps(
|
||||||
|
train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
|
||||||
|
):
|
||||||
|
# Here we do a fake forward and backward pass on a smaller
|
||||||
|
# batch size to avoid OOM
|
||||||
|
# This assumes the model has no stateful layers (batch norm..)
|
||||||
|
assert (
|
||||||
|
next(probe_mod.parameters()).grad is None
|
||||||
|
), "Can't probe model if grads are not reset"
|
||||||
|
|
||||||
|
with probe:
|
||||||
|
probe.metadata = {
|
||||||
|
"it": train_state.step,
|
||||||
|
"global_step": train_state.step,
|
||||||
|
"loop": "lingua",
|
||||||
|
}
|
||||||
|
# Non compiled model uses roughly 2x memory in our exps
|
||||||
|
# So we divide bsz by 2 or seqlen by 2
|
||||||
|
probe_bsz = max(1, bsz // 2)
|
||||||
|
probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
|
||||||
|
probe_loss = probe_mod(
|
||||||
|
batch_x[:probe_bsz, :probe_seq],
|
||||||
|
batch_y[:probe_bsz, :probe_seq],
|
||||||
|
)
|
||||||
|
probe_loss.backward()
|
||||||
|
# We zero grads to cancel this fake step
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
next(probe_mod.parameters()).grad is None
|
||||||
|
), "Probe model shouldn't have grads at this point"
|
||||||
|
|
||||||
|
pred = model(
|
||||||
|
batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
|
||||||
|
|
||||||
|
# We scale loss with grad_acc_steps so the gradient is the same
|
||||||
|
# regardless of grad_acc_steps
|
||||||
|
loss = loss / args.grad_acc_steps
|
||||||
|
|
||||||
|
# backward on scaled loss to create scaled gradients
|
||||||
|
loss.backward()
|
||||||
|
# For logging we undo that scaling
|
||||||
|
loss = loss.detach() * args.grad_acc_steps
|
||||||
|
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
model.parameters(), max_norm=args.optim.clip, foreach=True
|
||||||
|
)
|
||||||
|
|
||||||
|
grad_norm = (
|
||||||
|
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
|
||||||
|
).item()
|
||||||
|
|
||||||
|
# optimizer step
|
||||||
|
if train_state.acc_step == 0:
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
train_state.step += 1
|
||||||
|
|
||||||
|
# updates the scale for next iteration
|
||||||
|
# training iteration complete
|
||||||
|
end_timer.record()
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)
|
||||||
|
|
||||||
|
# if profiler is active
|
||||||
|
if torch_profiler:
|
||||||
|
xformers.profiler.step()
|
||||||
|
|
||||||
|
# log metrics
|
||||||
|
if every_n_steps(
|
||||||
|
train_state,
|
||||||
|
args.logging.freq,
|
||||||
|
acc_step=None if args.logging.acc_freq else 0,
|
||||||
|
acc_freq=args.logging.acc_freq,
|
||||||
|
):
|
||||||
|
time_delta = timer() - time_last_log
|
||||||
|
wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)
|
||||||
|
|
||||||
|
gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
|
||||||
|
|
||||||
|
total_acc_steps = (
|
||||||
|
args.grad_acc_steps * train_state.step + train_state.acc_step
|
||||||
|
)
|
||||||
|
tokens_per_gpu = (
|
||||||
|
total_acc_steps * args.data.batch_size * args.data.seq_len
|
||||||
|
)
|
||||||
|
total_tokens = dp_degree * tokens_per_gpu
|
||||||
|
# This is an estimate and the correct values may change
|
||||||
|
# if you change the architecture
|
||||||
|
# Use xformer's analyze profile trace to get actual measurement
|
||||||
|
FLOPS = (
|
||||||
|
get_num_flop_per_token(
|
||||||
|
model_param_count - args.model.vocab_size * args.model.dim,
|
||||||
|
args.model.n_layers,
|
||||||
|
args.model.dim,
|
||||||
|
args.data.seq_len,
|
||||||
|
)
|
||||||
|
* wps
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = flatten_dict(
|
||||||
|
{
|
||||||
|
"global_step": train_state.step,
|
||||||
|
"acc_step": train_state.acc_step,
|
||||||
|
"speed": {
|
||||||
|
"wps": wps,
|
||||||
|
"FLOPS": FLOPS,
|
||||||
|
"curr_iter_time": curr_iter_time,
|
||||||
|
"data_load_time": data_load_time,
|
||||||
|
},
|
||||||
|
"optim": {
|
||||||
|
"grad_norm": grad_norm,
|
||||||
|
"lr": curr_lr,
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
},
|
||||||
|
"memory": gpu_mem_stats._asdict(),
|
||||||
|
},
|
||||||
|
sep="/",
|
||||||
|
)
|
||||||
|
|
||||||
|
to_sync = {}
|
||||||
|
to_sync["loss/out"] = loss.item()
|
||||||
|
metrics.update(dist_mean_dict(to_sync))
|
||||||
|
|
||||||
|
if get_is_master():
|
||||||
|
metric_logger.log(metrics)
|
||||||
|
|
||||||
|
gpu_memory_monitor.reset_peak_stats()
|
||||||
|
nwords_since_last_log = 0
|
||||||
|
time_last_log = timer()
|
||||||
|
logger.info(
|
||||||
|
f"step: {train_state.step}"
|
||||||
|
f" acc: {train_state.acc_step}"
|
||||||
|
f" loss: {round(loss.item(),4):>7}"
|
||||||
|
f" grad: {grad_norm:.2e}"
|
||||||
|
f" flops: {FLOPS:.2e}"
|
||||||
|
f" wps: {wps:.2e}"
|
||||||
|
f" iter: {curr_iter_time:>7}"
|
||||||
|
f" data: {data_load_time:>5}"
|
||||||
|
f" lr: {curr_lr:.2e}"
|
||||||
|
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
|
||||||
|
f" pow: {gpu_mem_stats.power_draw/1000} W"
|
||||||
|
)
|
||||||
|
|
||||||
|
saved = False
|
||||||
|
if every_n_steps(
|
||||||
|
train_state, args.checkpoint.dump.every, acc_step=0
|
||||||
|
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
|
||||||
|
saved = checkpoint.save(
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
train_state,
|
||||||
|
args,
|
||||||
|
device_mesh=world_mesh,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.eval is not None and every_n_steps(
|
||||||
|
train_state, args.checkpoint.eval.every, acc_step=0
|
||||||
|
):
|
||||||
|
from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
|
||||||
|
|
||||||
|
eval_args = dataclass_from_dict(EvalArgs, args.eval)
|
||||||
|
|
||||||
|
eval_args.global_step = train_state.step
|
||||||
|
eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
|
||||||
|
eval_args.dump_dir = str(
|
||||||
|
os.path.join(
|
||||||
|
args.dump_dir,
|
||||||
|
"evals",
|
||||||
|
EVAL_FOLDER_NAME.format(train_state.step),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
eval_args.metric_log_dir = args.dump_dir
|
||||||
|
if args.async_eval_gpus is None:
|
||||||
|
launch_eval(eval_args)
|
||||||
|
elif get_is_master():
|
||||||
|
if wandb.run is not None and args.logging.wandb is not None:
|
||||||
|
eval_args.wandb = deepcopy(args.logging.wandb)
|
||||||
|
assert args.async_eval_gpus > 0
|
||||||
|
logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
|
||||||
|
with clean_env():
|
||||||
|
launch_job(
|
||||||
|
StoolArgs(
|
||||||
|
asdict(eval_args),
|
||||||
|
script="apps.main.eval",
|
||||||
|
copy_code=False,
|
||||||
|
nodes=args.async_eval_gpus // 8,
|
||||||
|
qos="lowest",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if preemption_flag["flag"]:
|
||||||
|
if not saved:
|
||||||
|
checkpoint.save(
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
train_state,
|
||||||
|
args,
|
||||||
|
device_mesh=world_mesh,
|
||||||
|
)
|
||||||
|
requeue_slurm_job()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
if not saved:
|
||||||
|
checkpoint.save(
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
train_state,
|
||||||
|
args,
|
||||||
|
device_mesh=world_mesh,
|
||||||
|
)
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
|
||||||
|
This accepts arguments as a dot list
|
||||||
|
So if the dataclass looks like
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DummyArgs:
|
||||||
|
name: str
|
||||||
|
model: LMTransformerArgsgs
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LMTransformerArgsgs:
|
||||||
|
dim: int
|
||||||
|
|
||||||
|
Then you can pass model.dim=32 to change values in LMTransformerArgsgs
|
||||||
|
or just name=tictac for top level attributes.
|
||||||
|
|
||||||
|
The behavior here is as follows:
|
||||||
|
1. We instantiate TrainArgs with its default values
|
||||||
|
2. We override those default values with the ones in the provided config file
|
||||||
|
3. We override the result with the additional arguments provided through command line
|
||||||
|
|
||||||
|
For example, if the config is the following
|
||||||
|
|
||||||
|
model:
|
||||||
|
dim: 128
|
||||||
|
n_layers: 4
|
||||||
|
|
||||||
|
and you call train.py with train.py model.dim=64
|
||||||
|
|
||||||
|
Then the final TrainArgs will have
|
||||||
|
|
||||||
|
model:
|
||||||
|
dim: 64
|
||||||
|
n_layers: 4
|
||||||
|
|
||||||
|
Plus all the default values in TrainArgs dataclass.
|
||||||
|
"""
|
||||||
|
cli_args = OmegaConf.from_cli()
|
||||||
|
file_cfg = OmegaConf.load(cli_args.config)
|
||||||
|
# We remove 'config' attribute from config as the underlying DataClass does not have it
|
||||||
|
del cli_args.config
|
||||||
|
|
||||||
|
default_cfg = OmegaConf.create(TrainArgs().model_dump())
|
||||||
|
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
|
||||||
|
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
|
||||||
|
train_args = TrainArgs.model_validate(cfg)
|
||||||
|
train(train_args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
216
bytelatent/transformer.py
Normal file
216
bytelatent/transformer.py
Normal file
|
@ -0,0 +1,216 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.distributed._tensor import Replicate, Shard
|
||||||
|
from torch.distributed.tensor.parallel import (
|
||||||
|
ColwiseParallel,
|
||||||
|
PrepareModuleInput,
|
||||||
|
RowwiseParallel,
|
||||||
|
SequenceParallel,
|
||||||
|
parallelize_module,
|
||||||
|
)
|
||||||
|
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
|
||||||
|
from xformers.ops import AttentionBias, fmha
|
||||||
|
|
||||||
|
from bytelatent.base_transformer import (
|
||||||
|
BaseTransformer,
|
||||||
|
BaseTransformerArgs,
|
||||||
|
RMSNorm,
|
||||||
|
cross_entropy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_causal_mask(seqlen, attn_impl, sliding_window):
|
||||||
|
if sliding_window is not None and attn_impl == "xformers":
|
||||||
|
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
|
||||||
|
window_left=sliding_window - 1, window_right=0
|
||||||
|
)
|
||||||
|
elif attn_impl == "xformers":
|
||||||
|
return fmha.attn_bias.LowerTriangularMask()
|
||||||
|
elif attn_impl == "sdpa":
|
||||||
|
return "causal"
|
||||||
|
elif attn_impl == "flex_attention":
|
||||||
|
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def attention_flops_per_token(n_layers, seq_len, dim, causal):
|
||||||
|
# Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
|
||||||
|
return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_flop_per_token(
|
||||||
|
num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
|
||||||
|
) -> int:
|
||||||
|
return 6 * num_non_embed_params + attention_flops_per_token(
|
||||||
|
n_layers, seq_len, dim, True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def causal_mask(b, h, q_idx, kv_idx):
|
||||||
|
return q_idx >= kv_idx
|
||||||
|
|
||||||
|
|
||||||
|
class LMTransformerArgs(BaseTransformerArgs):
|
||||||
|
seed: int = 42
|
||||||
|
|
||||||
|
vocab_size: int = -1
|
||||||
|
weight_tying: bool = False
|
||||||
|
|
||||||
|
sliding_window: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LMTransformer(BaseTransformer):
|
||||||
|
def __init__(self, args: LMTransformerArgs):
|
||||||
|
super().__init__(args)
|
||||||
|
self.weight_tying = args.weight_tying
|
||||||
|
self.sliding_window = args.sliding_window
|
||||||
|
|
||||||
|
assert args.vocab_size > 0
|
||||||
|
|
||||||
|
self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
|
||||||
|
self.output = nn.Linear(
|
||||||
|
args.dim,
|
||||||
|
args.vocab_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.weight_tying:
|
||||||
|
self.output.weight = self.embeddings.tok_embeddings.weight
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
token_values: torch.Tensor,
|
||||||
|
target: Optional[torch.Tensor] = None,
|
||||||
|
tok_idx: Optional[torch.Tensor] = None,
|
||||||
|
mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
|
||||||
|
attn_impl: str = "sdpa",
|
||||||
|
):
|
||||||
|
bsz, seqlen = token_values.shape
|
||||||
|
|
||||||
|
h = self.tok_embeddings(token_values)
|
||||||
|
|
||||||
|
mask = (
|
||||||
|
mask
|
||||||
|
if mask is not None
|
||||||
|
else create_causal_mask(seqlen, attn_impl, self.sliding_window)
|
||||||
|
)
|
||||||
|
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
|
||||||
|
|
||||||
|
logits = self.output(self.norm(h))
|
||||||
|
if target is not None:
|
||||||
|
return cross_entropy(logits, target)
|
||||||
|
else:
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def reset_parameters(self, init_std=None):
|
||||||
|
# Either use fixed base std or sqrt model dim
|
||||||
|
super().reset_parameters()
|
||||||
|
init_std = init_std or (self.dim ** (-0.5))
|
||||||
|
self.norm.reset_parameters()
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.tok_embeddings.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=init_std,
|
||||||
|
a=-3 * init_std,
|
||||||
|
b=3 * init_std,
|
||||||
|
)
|
||||||
|
if not self.weight_tying:
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
self.output.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=init_std,
|
||||||
|
a=-3 * init_std,
|
||||||
|
b=3 * init_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Optional policy for activation checkpointing. With None, we stick to the default (defined distributed.py: default_no_recompute_ops)
|
||||||
|
def get_no_recompute_ops():
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Optional and only used for fully shard options (fsdp) is choose. Highly recommanded for large models
|
||||||
|
def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
|
||||||
|
group_plan: Tuple[int, bool] = []
|
||||||
|
|
||||||
|
# Grouping and output seperately
|
||||||
|
group_plan.append(("tok_embeddings", False))
|
||||||
|
|
||||||
|
# Grouping by layers
|
||||||
|
for i in range(model_args.n_layers):
|
||||||
|
group_plan.append((f"layers.{i}", False))
|
||||||
|
|
||||||
|
group_plan.append(("output", True))
|
||||||
|
|
||||||
|
return group_plan
|
||||||
|
|
||||||
|
|
||||||
|
# Optional and only used for model/tensor parallelism when tp_size > 1
|
||||||
|
def tp_parallelize(model, tp_mesh, model_args: LMTransformerArgs, distributed_args):
|
||||||
|
assert model_args.dim % distributed_args.tp_size == 0
|
||||||
|
assert model_args.vocab_size % distributed_args.tp_size == 0
|
||||||
|
assert model_args.n_heads % distributed_args.tp_size == 0
|
||||||
|
assert (model_args.n_kv_heads or 0) % distributed_args.tp_size == 0
|
||||||
|
assert model_args.n_heads % (model_args.n_kv_heads or 1) == 0
|
||||||
|
|
||||||
|
# Embedding layer tp
|
||||||
|
main_plan = {}
|
||||||
|
main_plan["tok_embeddings"] = ColwiseParallel(
|
||||||
|
input_layouts=Replicate(), output_layouts=Shard(1)
|
||||||
|
)
|
||||||
|
main_plan["norm"] = SequenceParallel()
|
||||||
|
main_plan["output"] = ColwiseParallel(
|
||||||
|
input_layouts=Shard(1), output_layouts=Replicate()
|
||||||
|
)
|
||||||
|
|
||||||
|
parallelize_module(
|
||||||
|
model,
|
||||||
|
tp_mesh,
|
||||||
|
main_plan,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Attention layers tp
|
||||||
|
for layer in model.layers:
|
||||||
|
layer_plan = {}
|
||||||
|
|
||||||
|
layer_plan["attention"] = PrepareModuleInput(
|
||||||
|
input_layouts=(Shard(1), None),
|
||||||
|
desired_input_layouts=(Replicate(), None),
|
||||||
|
)
|
||||||
|
layer_plan["attention_norm"] = SequenceParallel()
|
||||||
|
layer_plan["attention.wq"] = ColwiseParallel()
|
||||||
|
layer_plan["attention.wk"] = ColwiseParallel()
|
||||||
|
layer_plan["attention.wv"] = ColwiseParallel()
|
||||||
|
layer_plan["attention.wo"] = RowwiseParallel(output_layouts=Shard(1))
|
||||||
|
|
||||||
|
# Feedforward layers tp
|
||||||
|
layer_plan["feed_forward"] = PrepareModuleInput(
|
||||||
|
input_layouts=(Shard(1),),
|
||||||
|
desired_input_layouts=(Replicate(),),
|
||||||
|
)
|
||||||
|
layer_plan["ffn_norm"] = SequenceParallel()
|
||||||
|
layer_plan["feed_forward.w1"] = ColwiseParallel()
|
||||||
|
layer_plan["feed_forward.w3"] = ColwiseParallel()
|
||||||
|
layer_plan["feed_forward.w2"] = RowwiseParallel(output_layouts=Shard(1))
|
||||||
|
|
||||||
|
parallelize_module(
|
||||||
|
layer,
|
||||||
|
tp_mesh,
|
||||||
|
layer_plan,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Adjusting the number of heads and kv heads according to the tp size
|
||||||
|
attn_layer = layer.attention
|
||||||
|
attn_layer.n_heads = attn_layer.n_heads // distributed_args.tp_size
|
||||||
|
attn_layer.n_kv_heads = attn_layer.n_kv_heads // distributed_args.tp_size
|
3
dev/lint.sh
Normal file
3
dev/lint.sh
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
isort .
|
||||||
|
black .
|
1
fixtures/tokenizer_data.json
Normal file
1
fixtures/tokenizer_data.json
Normal file
|
@ -0,0 +1 @@
|
||||||
|
{"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 80, 105, 120, 43, 119, 36, 103, 108, 105, 103, 111, 36, 109, 106, 36, 120, 108, 105, 119, 105, 36, 120, 115, 111, 105, 114, 109, 126, 105, 118, 119, 36, 113, 101, 120, 103, 108, 37, 2]]}
|
1
fixtures/tokenizer_data_bpe_delim.json
Normal file
1
fixtures/tokenizer_data_bpe_delim.json
Normal file
|
@ -0,0 +1 @@
|
||||||
|
{"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 3, 80, 105, 120, 3, 43, 3, 119, 3, 36, 103, 108, 105, 103, 111, 3, 36, 109, 106, 3, 36, 120, 108, 105, 119, 105, 3, 36, 120, 115, 111, 105, 114, 3, 109, 126, 105, 118, 119, 3, 36, 113, 101, 120, 103, 108, 3, 37, 2]]}
|
1
fixtures/tokens_with_entropies.json
Normal file
1
fixtures/tokens_with_entropies.json
Normal file
|
@ -0,0 +1 @@
|
||||||
|
{"position":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":6,"7":7,"8":8,"9":9,"10":10,"11":11,"12":12,"13":13,"14":14,"15":15,"16":16,"17":17,"18":18,"19":19,"20":20,"21":21,"22":22,"23":23,"24":24,"25":25,"26":26,"27":27,"28":28,"29":29,"30":30,"31":31,"32":32,"33":33,"34":34,"35":35,"36":36,"37":37,"38":38,"39":39,"40":40,"41":41,"42":42,"43":43,"44":44,"45":45,"46":46,"47":47,"48":48,"49":49,"50":50,"51":51,"52":52,"53":53,"54":54,"55":55,"56":56,"57":57,"58":58,"59":59,"60":60,"61":61,"62":62,"63":63,"64":64,"65":65,"66":66,"67":67,"68":68,"69":69,"70":70,"71":71,"72":72,"73":73,"74":74,"75":75,"76":76,"77":77,"78":78,"79":79,"80":80},"tokens":{"0":"<","1":"D","2":"a","3":"e","4":"n","5":"e","6":"r","7":"y","8":"s","9":"_","10":"T","11":"a","12":"r","13":"g","14":"a","15":"r","16":"y","17":"e","18":"n","19":"_","20":"i","21":"s","22":"_","23":"i","24":"n","25":"_","26":"G","27":"a","28":"m","29":"e","30":"_","31":"o","32":"f","33":"_","34":"T","35":"h","36":"r","37":"o","38":"n","39":"e","40":"s","41":",","42":"_","43":"a","44":"_","45":"f","46":"a","47":"n","48":"t","49":"a","50":"s","51":"y","52":"_","53":"e","54":"p","55":"i","56":"c","57":"_","58":"b","59":"y","60":"_","61":"G","62":"e","63":"o","64":"r","65":"g","66":"e","67":"_","68":"R","69":".","70":"R","71":".","72":"_","73":"M","74":"a","75":"r","76":"t","77":"i","78":"n","79":".","80":">"},"token_ids":{"0":1,"1":72,"2":101,"3":105,"4":114,"5":105,"6":118,"7":125,"8":119,"9":36,"10":88,"11":101,"12":118,"13":107,"14":101,"15":118,"16":125,"17":105,"18":114,"19":36,"20":109,"21":119,"22":36,"23":109,"24":114,"25":36,"26":75,"27":101,"28":113,"29":105,"30":36,"31":115,"32":106,"33":36,"34":88,"35":108,"36":118,"37":115,"38":114,"39":105,"40":119,"41":48,"42":36,"43":101,"44":36,"45":106,"46":101,"47":114,"48":120,"49":101,"50":119,"51":125,"52":36,"53":105,"54":116,"55":109,"56":103,"57":36,"58":102,"59":125,"60":36,"61":75,"62":105,"63":115,"64":118,"65":107,"66":105,"67":36,"68":86,"69":50,"70":86,"71":50,"72":36,"73":81,"74":101,"75":118,"76":120,"77":109,"78":114,"79":50,"80":2},"entropies":{"0":3.3949158192,"1":2.1656451225,"2":2.3216569424,"3":2.8214058876,"4":1.5249242783,"5":0.0401624143,"6":0.0981037766,"7":0.0544578359,"8":0.3430138826,"9":1.0546212196,"10":0.25252828,"11":0.1494535804,"12":0.0624754503,"13":0.001355894,"14":0.0050173439,"15":0.0052358187,"16":0.0011725067,"17":0.0010307421,"18":1.0241208076,"19":3.6867966652,"20":0.4502205253,"21":0.0484119244,"22":2.2572875023,"23":0.3789347112,"24":1.0042934418,"25":2.9090054035,"26":1.8933598995,"27":1.3859074116,"28":0.3827198744,"29":0.2646365762,"30":1.7742085457,"31":0.0136727821,"32":0.0053820172,"33":0.5485631227,"34":0.2064044327,"35":0.0049266233,"36":0.0005439016,"37":0.0007023578,"38":0.0004170335,"39":0.0054524317,"40":1.1938130856,"41":0.0238215197,"42":3.1279797554,"43":1.3883389235,"44":3.0503094196,"45":1.695879817,"46":1.8551058769,"47":1.4570231438,"48":0.0047810897,"49":0.026396824,"50":0.6633765101,"51":0.3141393065,"52":2.8411159515,"53":1.143143177,"54":0.0520330966,"55":0.3398066461,"56":0.4140175879,"57":2.5563707352,"58":1.3370712996,"59":0.0227173548,"60":3.4447185993,"61":1.8576486111,"62":0.8189754486,"63":0.6776530743,"64":0.0677763447,"65":0.212713033,"66":0.1003480032,"67":0.1746164262,"68":0.4123829603,"69":0.5507118702,"70":0.1047425047,"71":0.0194335245,"72":0.001482119,"73":0.0009310447,"74":0.0002176317,"75":0.0076908777,"76":0.0003866984,"77":0.0008008487,"78":1.2395234108,"79":0.4564163089,"80":0.0000461392},"patch":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":5,"7":5,"8":5,"9":5,"10":5,"11":5,"12":5,"13":5,"14":5,"15":5,"16":5,"17":5,"18":5,"19":5,"20":6,"21":6,"22":6,"23":7,"24":7,"25":7,"26":8,"27":9,"28":10,"29":10,"30":10,"31":11,"32":11,"33":11,"34":11,"35":11,"36":11,"37":11,"38":11,"39":11,"40":11,"41":11,"42":11,"43":12,"44":13,"45":14,"46":15,"47":16,"48":17,"49":17,"50":17,"51":17,"52":17,"53":18,"54":18,"55":18,"56":18,"57":18,"58":19,"59":20,"60":20,"61":21,"62":22,"63":22,"64":22,"65":22,"66":22,"67":22,"68":22,"69":22,"70":22,"71":22,"72":22,"73":22,"74":22,"75":22,"76":22,"77":22,"78":22,"79":22,"80":22},"start":{"0":1,"1":1,"2":1,"3":1,"4":1,"5":1,"6":0,"7":0,"8":0,"9":0,"10":0,"11":0,"12":0,"13":0,"14":0,"15":0,"16":0,"17":0,"18":0,"19":0,"20":1,"21":0,"22":0,"23":1,"24":0,"25":0,"26":1,"27":1,"28":1,"29":0,"30":0,"31":1,"32":0,"33":0,"34":0,"35":0,"36":0,"37":0,"38":0,"39":0,"40":0,"41":0,"42":0,"43":1,"44":1,"45":1,"46":1,"47":1,"48":1,"49":0,"50":0,"51":0,"52":0,"53":1,"54":0,"55":0,"56":0,"57":0,"58":1,"59":1,"60":0,"61":1,"62":1,"63":0,"64":0,"65":0,"66":0,"67":0,"68":0,"69":0,"70":0,"71":0,"72":0,"73":0,"74":0,"75":0,"76":0,"77":0,"78":0,"79":0,"80":0}}
|
1
plot_data/entropy_figure.json
Normal file
1
plot_data/entropy_figure.json
Normal file
File diff suppressed because one or more lines are too long
5
pyproject.toml
Normal file
5
pyproject.toml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
[tool.isort]
|
||||||
|
profile = "black"
|
||||||
|
known_bytelatent = "bytelatent"
|
||||||
|
known_apps = "apps"
|
||||||
|
sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER"
|
22
requirements.txt
Normal file
22
requirements.txt
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
numpy
|
||||||
|
omegaconf
|
||||||
|
msgspec
|
||||||
|
rouge-score
|
||||||
|
sacrebleu
|
||||||
|
sentencepiece
|
||||||
|
tiktoken
|
||||||
|
fsspec
|
||||||
|
blobfile
|
||||||
|
wandb
|
||||||
|
viztracer
|
||||||
|
lm-eval
|
||||||
|
scipy
|
||||||
|
pynvml
|
||||||
|
datatrove
|
||||||
|
orjson
|
||||||
|
luigi
|
||||||
|
pydantic
|
||||||
|
altair
|
||||||
|
submitit
|
||||||
|
typer
|
||||||
|
rich
|
48
setup/create_env.sh
Normal file
48
setup/create_env.sh
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
#SBATCH --job-name=env_creation
|
||||||
|
#SBATCH --nodes=1
|
||||||
|
#SBATCH --ntasks=1
|
||||||
|
#SBATCH --gres=gpu:8
|
||||||
|
#SBATCH --exclusive
|
||||||
|
#SBATCH --ntasks-per-node=1
|
||||||
|
#SBATCH --cpus-per-task=128
|
||||||
|
#SBATCH --mem=0
|
||||||
|
#SBATCH --time=01:00:00
|
||||||
|
|
||||||
|
# Exit immediately if a command exits with a non-zero status
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Start timer
|
||||||
|
start_time=$(date +%s)
|
||||||
|
|
||||||
|
# Get the current date
|
||||||
|
current_date=$(date +%y%m%d)
|
||||||
|
|
||||||
|
# Create environment name with the current date
|
||||||
|
env_prefix=blt_$current_date
|
||||||
|
|
||||||
|
# Create the conda environment
|
||||||
|
|
||||||
|
source $CONDA_ROOT/etc/profile.d/conda.sh
|
||||||
|
conda create -n $env_prefix python=3.11 -y -c anaconda
|
||||||
|
conda activate $env_prefix
|
||||||
|
|
||||||
|
echo "Currently in env $(which python)"
|
||||||
|
|
||||||
|
# Install packages
|
||||||
|
pip install torch==2.5.0 xformers --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
pip install ninja
|
||||||
|
pip install --requirement requirements.txt
|
||||||
|
|
||||||
|
# End timer
|
||||||
|
end_time=$(date +%s)
|
||||||
|
|
||||||
|
# Calculate elapsed time in seconds
|
||||||
|
elapsed_time=$((end_time - start_time))
|
||||||
|
|
||||||
|
# Convert elapsed time to minutes
|
||||||
|
elapsed_minutes=$((elapsed_time / 60))
|
||||||
|
|
||||||
|
echo "Environment $env_prefix created and all packages installed successfully in $elapsed_minutes minutes!"
|
156
setup/download_prepare_hf_data.py
Normal file
156
setup/download_prepare_hf_data.py
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
def run_command(command):
|
||||||
|
print(f"Running: {command}")
|
||||||
|
subprocess.run(command, shell=True, check=True)
|
||||||
|
|
||||||
|
|
||||||
|
def download_dataset(repo_id, local_dir, allow_patterns):
|
||||||
|
print(f"Downloading dataset from {repo_id}...")
|
||||||
|
max_retries = 5
|
||||||
|
retry_delay = 10 # seconds
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
snapshot_download(
|
||||||
|
repo_id,
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=local_dir,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
resume_download=True,
|
||||||
|
max_workers=16, # Don't hesitate to increase this number to lower the download time
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except requests.exceptions.ReadTimeout:
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
print(f"Timeout occurred. Retrying in {retry_delay} seconds...")
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
print(f"Dataset downloaded to {local_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
|
||||||
|
from datatrove.executor import LocalPipelineExecutor
|
||||||
|
from datatrove.pipeline.readers import ParquetReader
|
||||||
|
from datatrove.pipeline.writers import JsonlWriter
|
||||||
|
|
||||||
|
pipeline_exec = LocalPipelineExecutor(
|
||||||
|
pipeline=[
|
||||||
|
ParquetReader(
|
||||||
|
src_dir,
|
||||||
|
file_progress=True,
|
||||||
|
doc_progress=True,
|
||||||
|
glob_pattern="**/*.parquet",
|
||||||
|
),
|
||||||
|
JsonlWriter(
|
||||||
|
tgt_dir,
|
||||||
|
output_filename=dataset + ".chunk.${rank}.jsonl",
|
||||||
|
compression=None,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tasks=ntasks,
|
||||||
|
logging_dir=os.path.join(work_dir, "datatrove"),
|
||||||
|
)
|
||||||
|
pipeline_exec.run()
|
||||||
|
|
||||||
|
|
||||||
|
def setup_terashuf(work_dir):
|
||||||
|
terashuf_dir = os.path.join(work_dir, "terashuf")
|
||||||
|
terashuf_executable = os.path.join(terashuf_dir, "terashuf")
|
||||||
|
|
||||||
|
if os.path.exists(terashuf_executable):
|
||||||
|
print("terashuf executable already exists. Skipping setup.")
|
||||||
|
return terashuf_dir
|
||||||
|
|
||||||
|
print("Setting up terashuf...")
|
||||||
|
run_command(f"git clone https://github.com/alexandres/terashuf {terashuf_dir}")
|
||||||
|
run_command(f"make -C {terashuf_dir}")
|
||||||
|
return terashuf_dir
|
||||||
|
|
||||||
|
|
||||||
|
def main(dataset, memory, data_dir, seed=42, nchunks=32):
|
||||||
|
# Configuration
|
||||||
|
repo_id = {
|
||||||
|
"fineweb_edu": "HuggingFaceFW/fineweb-edu",
|
||||||
|
"fineweb_edu_10bt": "HuggingFaceFW/fineweb-edu",
|
||||||
|
"dclm_baseline_1.0": "mlfoundations/dclm-baseline-1.0",
|
||||||
|
"dclm_baseline_1.0_10prct": "mlfoundations/dclm-baseline-1.0",
|
||||||
|
}[dataset]
|
||||||
|
src_dir = f"{data_dir}/{dataset}"
|
||||||
|
out_dir = f"{src_dir}_shuffled"
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
work_dir = src_dir # Directory of this Python file
|
||||||
|
prefix = f"{dataset}.chunk."
|
||||||
|
orig_extension = {
|
||||||
|
"fineweb_edu": ".jsonl",
|
||||||
|
"fineweb_edu_10bt": ".jsonl",
|
||||||
|
"dclm_baseline_1.0": ".jsonl.zst",
|
||||||
|
"dclm_baseline_1.0_10prct": ".jsonl.zst",
|
||||||
|
}[dataset]
|
||||||
|
cat_command = {
|
||||||
|
"fineweb_edu": "cat",
|
||||||
|
"fineweb_edu_10bt": "cat",
|
||||||
|
"dclm_baseline_1.0": "zstdcat",
|
||||||
|
"dclm_baseline_1.0_10prct": "zstdcat",
|
||||||
|
}[dataset]
|
||||||
|
allow_patterns = {
|
||||||
|
"fineweb_edu": None,
|
||||||
|
"fineweb_edu_10bt": "sample/10BT/*",
|
||||||
|
"dclm_baseline_1.0": "*.jsonl.zst",
|
||||||
|
"dclm_baseline_1.0_10prct": "global-shard_01_of_10/*.jsonl.zst",
|
||||||
|
}[dataset]
|
||||||
|
suffix = ".jsonl"
|
||||||
|
k_validation = 10000 # Number of lines to take from each chunk for validation
|
||||||
|
|
||||||
|
# Setup terashuf
|
||||||
|
terashuf_dir = setup_terashuf(work_dir)
|
||||||
|
|
||||||
|
# Download dataset
|
||||||
|
download_dataset(repo_id, src_dir, allow_patterns)
|
||||||
|
|
||||||
|
if "fineweb" in dataset:
|
||||||
|
parquet_to_jsonl(dataset, work_dir, src_dir, src_dir)
|
||||||
|
|
||||||
|
# Set up environment variables
|
||||||
|
os.environ["MEMORY"] = f"{memory}"
|
||||||
|
os.environ["SEED"] = f"{seed}"
|
||||||
|
|
||||||
|
# Run the original shuffling and splitting command
|
||||||
|
terashuf_executable = os.path.join(terashuf_dir, "terashuf")
|
||||||
|
run_command(
|
||||||
|
f"ulimit -n 100000 && "
|
||||||
|
f"find {src_dir} -type f -name '*{orig_extension}' -print0 | xargs -0 {cat_command} | {terashuf_executable} | "
|
||||||
|
f"split -n r/{nchunks} -d --suffix-length 2 --additional-suffix {suffix} - {out_dir}/{prefix}"
|
||||||
|
"; trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' SIGPIPE;"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create validation set and remove lines from chunks
|
||||||
|
validation_file = f"{out_dir}/{dataset}.val{suffix}"
|
||||||
|
for i in range(nchunks):
|
||||||
|
chunk_file = f"{out_dir}/{prefix}{i:02d}{suffix}"
|
||||||
|
run_command(f"head -n {k_validation} {chunk_file} >> {validation_file}")
|
||||||
|
run_command(f"sed -i '1,{k_validation}d' {chunk_file}")
|
||||||
|
|
||||||
|
print("All tasks completed successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("dataset", type=str)
|
||||||
|
parser.add_argument("memory", type=float, default=8)
|
||||||
|
parser.add_argument("--data_dir", type=str, default="data")
|
||||||
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
|
parser.add_argument("--nchunks", type=int, default=32)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args.dataset, args.memory, args.data_dir, args.seed, args.nchunks)
|
60
setup/download_tokenizer.py
Normal file
60
setup/download_tokenizer.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
TOKENIZER = {
|
||||||
|
"llama2": ("meta-llama/Llama-2-7b", "tokenizer.model"),
|
||||||
|
"llama3": ("meta-llama/Meta-Llama-3-8B", "original/tokenizer.model"),
|
||||||
|
"gemma": ("google/gemma-2-9b", "tokenizer.model"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main(tokenizer_name: str, path_to_save: str, api_key: Optional[str] = None):
|
||||||
|
if tokenizer_name in TOKENIZER:
|
||||||
|
repo_id, filename = TOKENIZER[tokenizer_name]
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
try:
|
||||||
|
hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename=filename,
|
||||||
|
local_dir=path_to_save,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
token=api_key if api_key else None,
|
||||||
|
)
|
||||||
|
except HTTPError as e:
|
||||||
|
if e.response.status_code == 401:
|
||||||
|
print(
|
||||||
|
"You need to pass a valid `--hf_token=...` to download private checkpoints."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
from tiktoken import get_encoding
|
||||||
|
|
||||||
|
if "TIKTOKEN_CACHE_DIR" not in os.environ:
|
||||||
|
os.environ["TIKTOKEN_CACHE_DIR"] = path_to_save
|
||||||
|
try:
|
||||||
|
get_encoding(tokenizer_name)
|
||||||
|
except ValueError:
|
||||||
|
print(
|
||||||
|
f"Tokenizer {tokenizer_name} not found. Please check the name and try again."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("tokenizer_name", type=str)
|
||||||
|
parser.add_argument("tokenizer_dir", type=str, default=8)
|
||||||
|
parser.add_argument("--api_key", type=str, default="")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(
|
||||||
|
tokenizer_name=args.tokenizer_name,
|
||||||
|
path_to_save=args.tokenizer_dir,
|
||||||
|
api_key=args.api_key,
|
||||||
|
)
|
Loading…
Reference in a new issue