Initial commit

Pedro Rodriguez 2024-12-12 15:32:30 -08:00
commit bcc039bb75
86 changed files with 12203 additions and 0 deletions

12
.github/workflows/black.yml vendored Normal file

@@ -0,0 +1,12 @@
name: Lint with Black
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: psf/black@stable
with:
version: "24.8.0"

10
.github/workflows/isort.yml vendored Normal file

@@ -0,0 +1,10 @@
name: Lint with isort
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: isort/isort-action@master

168
.gitignore vendored Normal file

@@ -0,0 +1,168 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.out
figures/
.vscode/
.DS_Store

8
.prettierrc Normal file

@@ -0,0 +1,8 @@
{
"overrides": [
{
"files": "*.yaml",
"options": { "tabWidth": 2 }
}
]
}

80
CODE_OF_CONDUCT.md Normal file

@@ -0,0 +1,80 @@
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

36
CONTRIBUTING.md Normal file

@@ -0,0 +1,36 @@
# Contributing to BLT
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to BLT, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

28
LICENSE Normal file

@@ -0,0 +1,28 @@
BSD 3-Clause License
Copyright 2024 Meta
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list
of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may
be used to endorse or promote products derived from this software without specific
prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.

117
README.md Normal file

@@ -0,0 +1,117 @@
# Byte Latent Transformer
This repository contains code for our paper: "Byte Latent Transformer: Patches Scale Better Than Tokens"
- [Paper Link](https://dl.fbaipublicfiles.com/blt/BLT__Patches_Scale_Better_Than_Tokens.pdf)
## Abstract
We introduce the Byte Latent Transformer (BLT), a new byte-level LLM architecture that, for the
first time, matches tokenization-based LLM performance at scale, with significant improvements
in inference efficiency and robustness. BLT encodes bytes into dynamically sized patches, which serve
as the primary units of computation. Patches are segmented dynamically based on the entropy of the
next byte, allocating more compute and model capacity where there is more data complexity. The BLT
architecture includes new attention mechanisms to maximize the information flow between byte and
patch hidden representations and a new type of byte-sequence memory. We present the first scaling
study of byte-level models up to 8B parameters and 8T training bytes, showing for the first time
that we can train a model end-to-end at scale from bytes with no tokenization or other preprocessing.
Scaling trends reveal training and inference efficiency benefits from dynamically selecting very long
patches on average, along with qualitative improvements in reasoning and long-tail generalization
from modeling byte-sequences.
![BLT Architecture Diagram](blt-figure.jpg)
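To make the patching idea concrete, here is a minimal sketch of entropy-based patch segmentation. It is an illustration only, not the code in this repository: BLT estimates next-byte entropy with a small byte-level language model, whereas this sketch uses a crude sliding-window byte histogram, and the threshold is a placeholder value.
```python
import math

def next_byte_entropy(data: bytes, i: int, window: int = 64) -> float:
    # Crude stand-in for a learned next-byte entropy model: entropy of the
    # byte histogram over a small preceding context window.
    ctx = data[max(0, i - window) : i + 1]
    counts = {}
    for b in ctx:
        counts[b] = counts.get(b, 0) + 1
    total = len(ctx)
    return -sum(c / total * math.log2(c / total) for c in counts.values())

def segment_patches(data: bytes, threshold: float = 2.5) -> list:
    # Start a new patch whenever the estimated next-byte entropy is high,
    # so "hard" regions get shorter patches (i.e. more compute per byte).
    patches, start = [], 0
    for i in range(1, len(data)):
        if next_byte_entropy(data, i) > threshold:
            patches.append(data[start:i])
            start = i
    patches.append(data[start:])
    return [p for p in patches if p]

print(segment_patches(b"aaaaaaaaaaHello, world!aaaaaaaaaa"))
```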
## Development Status
We are actively updating the BLT code to make it easier to reproduce our results.
Please file an issue and/or be patient while we make more of our code public!
## Quick start
The following commands create the conda environment, either locally or by launching a SLURM job.
The env creation should take around 5 minutes without counting downloads.
```bash
git clone https://github.com/facebookresearch/blt
cd blt
bash setup/create_env.sh
# or if you have access to a SLURM cluster
sbatch setup/create_env.sh
```
Once that is done, you can activate the environment:
```bash
conda activate blt_<date>
```
Use the provided script to download and prepare data from Hugging Face (choosing among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
This command downloads `fineweb_edu`, prepares it for training in the `./data` directory, and lets you specify how much memory `terashuf` (the tool used to shuffle samples) is allocated. By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.
```bash
python setup/download_prepare_hf_data.py fineweb_edu <MEMORY> --data_dir ./data --seed 42 --nchunks <NCHUNKS>
```
To download a tokenizer (here llama3), use the following script:
```bash
python setup/download_tokenizer.py llama3 <SAVE_PATH> --api_key <HUGGINGFACE_TOKEN>
```
Now launch a debug job to check that everything works. **The provided configurations are templates; you need to adapt them before they will work (change `dump_dir`, `data.root_dir`, `data.tokenizer.path`, etc.).**
```bash
# stool stands for SLURM tool !
python -m bytelatent.stool script=bytelatent.train config=apps/bytelatent/configs/debug.yaml nodes=1 partition=<partition>
# if you want to launch locally you can use torchrun
torchrun --nproc-per-node 8 -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
# or you can also launch on 1 GPU
python -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
```
When using `stool`, if a job crashes, it can be relaunched using sbatch:
```bash
sbatch path/to/dump_dir/submit.slurm
```
## Linting
To lint, run the following command:
```
bash dev/lint.sh
```
## Citation
BLT is partially based on Meta Lingua, so consider citing Lingua in addition to our BLT paper if you reuse our work.
BLT Paper Citation (will be updated to arXiv soon)
```
@article{meta_blt,
author = {Artidoro Pagnoni, Ram Pasunuru, Pedro Rodriguez, John Nguyen, Benjamin Muller, Margaret Li, Chunting Zhou, Lili Yu, Jason Weston, Luke Zettlemoyer, Gargi Ghosh, Mike Lewis, Ari Holtzman†, Srinivasan Iyer},
title = {Byte Latent Transformer: Patches Scale Better Than Tokens},
url = {https://github.com/facebookresearch/blt},
year = {2024}
}
```
Lingua Code
```
@misc{meta_lingua,
author = {Mathurin Videau, Badr Youbi Idrissi, Daniel Haziza, Luca Wehrstedt, Jade Copet, Olivier Teytaud, David Lopez-Paz},
title = {{Meta Lingua}: A minimal {PyTorch LLM} training library},
url = {https://github.com/facebookresearch/lingua},
year = {2024}
}
```
## License
The BLT code is partially based on Meta Lingua.
Meta Lingua is licensed under the BSD-3-Clause license. Refer to the LICENSE file in the top-level directory.

0
apps/__init__.py Normal file

0
apps/main/__init__.py Normal file

View file

@@ -0,0 +1,35 @@
name: "debug_evals"
# ckpt_dir: !!CHANGETHIS!!
# dump_dir: !!CHANGETHIS!!
generator:
max_tokens: 8192
dtype: bf16
temperature: 1.0
top_p: 0.95
harness:
tasks:
- hellaswag
- task: boolq
dataset_kwargs:
trust_remote_code: true
- task: nq_open
num_fewshot: 5
- piqa
- task: social_iqa
dataset_kwargs:
trust_remote_code: true
- triviaqa
- winogrande
- openbookqa
- arc_easy
- arc_challenge
- race
- commonsense_qa
# - coqa
- copa
- gsm8k
- bbh
- mmlu
- mmlu_pro
validation:
max_steps: 1000

View file

@@ -0,0 +1,87 @@
# dump_dir: !!!CHANGE_THIS!!!
name: large_lm
steps: 60_000
probe_freq: null
seed: 777
optim:
lr: 3e-3
weight_decay: 0.033
warmup: 5000
lr_min_ratio: 0.000001
clip: 1.0
distributed:
fsdp_type: full_shard
compile: true
model_dtype: bf16
matmul_allow_tf32: false
selective_activation_checkpointing: false
tp_size: 1
model:
dim: 2048
n_layers: 25
n_heads: 16
data:
root_dir: data/shuffled
sources:
dclm_baseline_1.0: 100.0
batch_size: 4
prefetch_size: 1024
seq_len: 4096
n_views: 2
load_async: true
add_bos: true
add_eos: true
tokenizer:
name: tiktoken
path: tokenizers/cl_toplang_128k.tiktoken
profiling:
run: true
mem_warmup: 0
mem_steps: 4
profile_warmup: 100
profile_steps: 4
checkpoint:
dump:
every: 2500
keep: 3
eval:
every: 5000
keep: -1
logging:
freq: 1
async_eval_gpus: 8
eval:
harness:
tasks:
- hellaswag
- task: boolq
dataset_kwargs:
trust_remote_code: true
- piqa
- task: social_iqa
dataset_kwargs:
trust_remote_code: true
- winogrande
- openbookqa
- arc_easy
- arc_challenge
- race
- commonsense_qa
- copa
# - coqa
# - task: nq_open
# num_fewshot: 5
# - triviaqa
validation:
max_steps: 1000
generator:
max_tokens: 16384
dtype: bf16

View file

@@ -0,0 +1,95 @@
#python -m lingua.stool config=apps/main/configs/llama2_7B.yaml nodes=32 account=fair_amaia_cw_codegen qos=lowest
# dump_dir: !!!CHANGE_THIS!!!
name: "7b_baseline"
steps: 100_000
grad_acc_steps: 1
probe_freq: 100
seed: 777
optim:
lr: 1.0e-3
weight_decay: 0.1
warmup: 2000
lr_min_ratio: 0.000001
clip: 1.0
distributed:
fsdp_type: full_shard
compile: true
model_dtype: bf16
matmul_allow_tf32: false
selective_activation_checkpointing: false
tp_size: 1
model:
dim: 4096
n_layers: 32
n_heads: 32
rope_theta: 100_000
ffn_dim_multiplier: 1.0
multiple_of: 256
data:
root_dir: data/shuffled
sources:
dclm_baseline_1.0: 1.0
batch_size: 2
prefetch_size: 1024
seq_len: 4096
n_views: 2
load_async: true
tokenizer:
name: tiktoken
path: tokenizers/cl_toplang_128k.tiktoken
profiling:
run: true
mem_warmup: 0
mem_steps: 4
profile_warmup: 100
profile_steps: 4
checkpoint:
dump:
every: 10000
keep: -1
eval:
every: 1000
keep: 3
logging:
freq: 1
async_eval_gpus: 8
eval:
dataset_dir: datasets/eval
harness:
tasks:
- hellaswag
- task: boolq
dataset_kwargs:
trust_remote_code: true
- piqa
- task: social_iqa
dataset_kwargs:
trust_remote_code: true
- winogrande
- openbookqa
- arc_easy
- arc_challenge
- race
- commonsense_qa
# - coqa
- copa
- mmlu
- mmlu_pro
# - task: nq_open
# num_fewshot: 5
# - triviaqa
# - gsm8k
# - bbh
validation:
max_steps: 1000
generator:
max_tokens: 8192
dtype: bf16

354
apps/main/eval.py Normal file

@@ -0,0 +1,354 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union
import torch
from lingua.args import dump_config
from lingua.data import init_choice_state, setup_sources
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from omegaconf import OmegaConf
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.distributed import (
DistributedArgs,
dist_mean_dict,
get_global_rank,
get_world_size,
setup_torch_distributed,
)
from bytelatent.transformer import LMTransformer, LMTransformerArgs
from apps.main.generate import (
PackedCausalTransformerGenerator,
PackedCausalTransformerGeneratorArgs,
load_consolidated_model_and_tokenizer,
)
EVAL_FOLDER_NAME = "{:010d}"
logger = logging.getLogger()
@dataclass
class LMHarnessArgs:
tasks: Optional[List[Any]] = None
num_fewshot: Optional[int] = None
device: Optional[str] = None
use_cache: Optional[str] = None
cache_requests: bool = False
rewrite_requests_cache: bool = False
delete_requests_cache: bool = False
limit: Optional[Union[int, float]] = None
bootstrap_iters: int = 100000
check_integrity: bool = False
write_out: bool = False
log_samples: bool = True
system_instruction: Optional[str] = None
apply_chat_template: Union[bool, str] = False
fewshot_as_multiturn: bool = False
gen_kwargs: Optional[str] = None
verbosity: str = "INFO"
predict_only: bool = False
random_seed: int = 0
numpy_random_seed: int = 1234
torch_random_seed: int = 1234
fewshot_random_seed: int = 1234
@dataclass
class ValidationArgs:
max_steps: Optional[int] = (
None  # If None, the whole validation file is used -> /!\ The number of steps is GPU-dependent (100 max steps on 8 GPUs = 800 steps on 1 GPU)
)
use_val_from_train_src: bool = True # Use the validation set from training sources
root_dir: str = ""
sources: List[str] = field(default_factory=list) # Other sources to eval on
@dataclass
class EvalArgs:
name: str = "evals"
dump_dir: Optional[str] = None
metric_log_dir: Optional[str] = None
ckpt_dir: str = ""
generator: PackedCausalTransformerGeneratorArgs = field(
default_factory=PackedCausalTransformerGeneratorArgs
)
harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)
wandb: Optional[Any] = None
global_step: Optional[int] = None # for in-training evaluation
def all_dicts_same(dict_list):
if not dict_list: # Check if the list is empty
return True
# Compare each dictionary to the first one
first_dict = dict_list[0]
return all(d == first_dict for d in dict_list)
class MockAccelerator:
def gather(self, tensor):
l = [torch.zeros_like(tensor) for _ in range(get_world_size())]
torch.distributed.all_gather(l, tensor)
return torch.stack(l)
def wait_for_everyone(self):
torch.distributed.barrier()
# Light wrapper around generator for lm-eval harness
class EvalHarnessLM(LM):
def __init__(self, generator):
super().__init__()
self.generator = generator
self.accelerator = MockAccelerator()
self._rank = get_global_rank()
self._world_size = get_world_size()
self.device = generator.device
def generate_until(self, requests: List[Instance]) -> List[str]:
prompts, gen_args = zip(*[req.args for req in requests])
assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
gen_args = gen_args[0]
temperature = gen_args.get("temperature", 0.0)
top_p = gen_args.get("top_p", None)
top_k = gen_args.get("top_k", None)
until = gen_args.get("until", [])
self.generator.temperature = temperature
self.generator.top_p = top_p
self.generator.top_k = top_k
self.generator.until = until
generations, _, _ = self.generator.generate(prompts)
filtered_gen = []
for g in generations:
for e in until:
g = g.replace(e, "")
filtered_gen.append(g)
return filtered_gen
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
prompts, continuations = zip(*[req.args for req in requests])
inputs = [req.args[0] + req.args[1] for req in requests]
max_gen_len = self.generator.max_gen_len
# We temporarily lower max gen len
self.generator.max_gen_len = 1
_, lls, greedy = self.generator.generate(inputs)
results = []
for p, ll, gr in zip(prompts, lls, greedy):
p_len = len(
self.generator.tokenizer.encode(p, add_bos=False, add_eos=False)
)
results.append((ll[p_len:].sum().item(), gr[p_len:].all().item()))
self.generator.max_gen_len = max_gen_len
return results
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
prompts = [req.args[0] for req in requests]
max_gen_len = self.generator.max_gen_len
# We temporarily lower max gen len
self.generator.max_gen_len = 1
_, lls, _ = self.generator.generate(prompts)
results = []
for ll in lls:
results.append((ll.sum().item(),))
self.generator.max_gen_len = max_gen_len
return results
def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
srcs = {}
for src in val_args.sources:
path = os.path.join(val_args.root_dir, src)
srcs[path] = 1.0
for src in train_cfg.data.sources:
path = os.path.join(train_cfg.data.root_dir, src)
srcs[path] = 1.0
multi_state = init_choice_state(
"", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
)
path_to_iter = setup_sources(multi_state)
max_gen_len = generator.max_gen_len
# We temporarily lower max gen len
generator.max_gen_len = 1
all_val_metrics = {}
for src in path_to_iter:
jsonl_iterator = path_to_iter[src]
texts = []
logger.info(f"Running validation on {src}...")
for step, (content, state) in enumerate(jsonl_iterator):
if state["current_iter"] > 0 or (
val_args.max_steps is not None and step >= val_args.max_steps
):
break
content_key = "text" if ("text" in content) else "content"
texts.append(content[content_key])
_, loglikelihood, _ = generator.generate(texts)
metrics = defaultdict(list)
for i, ll in enumerate(loglikelihood):
tmp = ll.sum().item()
metrics["nll"].append(tmp)
metrics["nll_per_token"].append(tmp / len(ll))
metrics["nll_per_char"].append(tmp / len(texts[i]))
metrics["avg_seqlen"].append(len(ll))
for m in metrics:
metrics[m] = sum(metrics[m]) / len(metrics[m])
metrics.update(dist_mean_dict(metrics))
logger.info(f"Validation on {src} done. Metrics: {metrics}")
name = os.path.basename(src)
if name in all_val_metrics:
logger.warning(
f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
)
name = f"{name}_1"
all_val_metrics[name] = metrics
generator.max_gen_len = max_gen_len
return all_val_metrics
def launch_eval(cfg: EvalArgs):
if not torch.distributed.is_initialized():
setup_torch_distributed(DistributedArgs())
if (
Path(cfg.ckpt_dir).exists()
and (Path(cfg.ckpt_dir) / "params.json").exists()
and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
):
consolidate_path = Path(cfg.ckpt_dir)
else:
consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
if not consolidate_path.exists() and get_global_rank() == 0:
consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)
Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)
consolidate_path = str(consolidate_path)
torch.distributed.barrier()
logger.info("Loading model")
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
consolidate_path,
model_cls=LMTransformer,
model_args_cls=LMTransformerArgs,
)
logger.info("Model loaded")
model.eval()
generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)
wrap = EvalHarnessLM(generator)
results = simple_evaluate(wrap, **asdict(cfg.harness))
val_results = None
if cfg.validation:
val_results = eval_on_val(generator, cfg.validation, train_cfg)
if get_global_rank() == 0:
with open(Path(cfg.dump_dir) / "results.json", "w") as f:
f.write(json.dumps(results))
logger.info(f"All evaluation results: {results['results']}")
if val_results is not None:
with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
f.write(json.dumps(val_results))
logger.info(f"All validation results: {val_results}")
if cfg.metric_log_dir and get_global_rank() == 0:
metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"
logger.info(f"Writing metric logs to {metric_log_path}")
timestamp = {
"created_at": datetime.utcnow().isoformat(),
}
if cfg.global_step is not None:
timestamp["global_step"] = cfg.global_step
print(
json.dumps(timestamp | results["results"]),
file=open(metric_log_path, mode="a"),
flush=True,
)
val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
if val_results is not None:
print(
json.dumps(timestamp | val_results),
file=open(val_log_path, mode="a"),
flush=True,
)
del generator
def main():
"""
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
This accepts arguments as a dot list
So if the dataclass looks like
@dataclass
class DummyArgs:
name: str
model: LMTransformerArgs
@dataclass
class LMTransformerArgs:
dim: int
Then you can pass model.dim=32 to change values in LMTransformerArgs
or just name=tictac for top level attributes.
The behavior here is as follows:
1. We instantiate EvalArgs with its default values
2. We override those default values with the ones in the provided config file
3. We override the result with the additional arguments provided through command line
For example, if the config is the following
model:
dim: 128
n_layers: 4
and you call eval.py with eval.py model.dim=64
Then the final EvalArgs will have
model:
dim: 64
n_layers: 4
Plus all the default values in EvalArgs dataclass.
"""
cli_args = OmegaConf.from_cli()
file_cfg = OmegaConf.load(cli_args.config)
# We remove 'config' attribute from config as the underlying DataClass does not have it
del cli_args.config
default_cfg = OmegaConf.structured(EvalArgs())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
cfg = OmegaConf.to_object(cfg)
launch_eval(cfg)
if __name__ == "__main__":
main()
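The defaults-then-YAML-then-CLI merge described in the docstring above can be illustrated with a standalone OmegaConf sketch; `DummyArgs` and `ModelArgs` here are illustrative stand-ins, not classes from this repository:
```python
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class ModelArgs:
    dim: int = 128
    n_layers: int = 4

@dataclass
class DummyArgs:
    name: str = "default"
    model: ModelArgs = field(default_factory=ModelArgs)

default_cfg = OmegaConf.structured(DummyArgs())
file_cfg = OmegaConf.create({"model": {"dim": 64}})  # stands in for OmegaConf.load("config.yaml")
cli_cfg = OmegaConf.from_dotlist(["name=tictac", "model.n_layers=8"])
# Later sources win: defaults <- YAML file <- CLI dot-list
cfg = OmegaConf.to_object(OmegaConf.merge(default_cfg, file_cfg, cli_cfg))
print(cfg)  # DummyArgs(name='tictac', model=ModelArgs(dim=64, n_layers=8))
```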

463
apps/main/generate.py Normal file

@@ -0,0 +1,463 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
import torch
from lingua.args import dataclass_from_dict
from lingua.tokenizers.abstract_tokenizer import Tokenizer
from lingua.tokenizers.build_tokenizer import build_tokenizer
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm
from bytelatent.base_transformer import (
Attention,
causal_mask,
generate_doc_mask_mod,
lengths_to_local_ids,
lengths_to_start_ids,
)
from bytelatent.checkpoint import CONSOLIDATE_NAME
from bytelatent.transformer import LMTransformer, LMTransformerArgs
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
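# Nucleus (top-p) sampling: keep the smallest set of highest-probability tokens
# whose cumulative probability exceeds p, zero out the rest, then sample.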
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
def sample_top_k(probs, k):
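# Top-k sampling: zero out all probabilities below the k-th largest,
# renormalize, and sample from the remaining k tokens.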
topk_value, _ = torch.topk(probs, k) # batch_sz x topk
min_value_top_k = topk_value[:, [-1]]
probs[probs < min_value_top_k] = 0.0
probs.div_(probs.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs, num_samples=1)
return next_token
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
shape = logits.shape
logits = logits.flatten(end_dim=-2)
if temperature > 0.0:
probs = torch.softmax(logits / temperature, dim=-1)
if top_p is not None:
next_token = sample_top_p(probs, top_p)
elif top_k is not None:
next_token = sample_top_k(probs, top_k)
else:
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(logits, dim=-1)
return next_token.view(shape[:-1])
def pack_prompts(prompts: List[int]):
res = []
lengths = []
for i, p in enumerate(prompts):
p = torch.tensor(p, dtype=torch.long)
l = p.size(0)
res.append(p)
lengths.append(l)
lengths = torch.tensor(lengths, dtype=torch.long)
res = torch.cat(res)
return res, lengths
def batch_prompts(prompts, max_elements, lengths=None):
batches = []
current_batch = []
current_count = 0
for i in range(len(prompts)):
prt = prompts[i]
prompt_size = len(prt) if lengths is None else lengths[i]
if current_count + prompt_size <= max_elements:
current_batch.append(prt)
current_count += prompt_size
else:
if current_batch: # Add the current batch to batches
batches.append(current_batch)
# Start a new batch with the current prompt
current_batch = [prt]
current_count = prompt_size
# Add the last batch if it contains any prompts
if current_batch:
batches.append(current_batch)
return batches
class KVCache(nn.Module):
def __init__(self, bsz, seqlen, n_heads, head_dim, dtype, device):
super().__init__()
shape = (bsz, seqlen, n_heads, head_dim)
self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype, device=device))
self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype, device=device))
self.offset = 0
def reset(self):
self.k_cache.zero_()
self.v_cache.zero_()
self.offset = 0
def update(self, k_val, v_val, tok_idx):
# input_pos: [B], k_val: [B, S, H, D]
self.k_cache.index_copy_(1, self.offset + tok_idx, k_val)
self.v_cache.index_copy_(1, self.offset + tok_idx, v_val)
return self.k_cache, self.v_cache
@dataclass
class PackedCausalTransformerGeneratorArgs:
temperature: float = 0.0
top_p: Optional[float] = None
top_k: Optional[float] = None
max_gen_len: int = 512 # Maximum number of tokens to generate
max_tokens: int = 1024 # Maximum number of tokens that can go through the model
max_prompt_len: Optional[int] = None
until: List[str] = field(default_factory=list)
compile_prefilling: bool = False
reduce_generation_overhead: bool = False
show_progress: bool = False
dtype: Optional[str] = "bf16"
device: Optional[str] = "cuda"
class PackedCausalTransformerGenerator:
def __init__(
self,
cfg: PackedCausalTransformerGeneratorArgs,
model: nn.Module,
tokenizer: Tokenizer,
):
"""
This class wraps a causal transformer model with its corresponding tokenizer
and provides an efficient way to pack prompts together and do generation on
the packed sequence.
For example, if we had the prompts "Hello, I am a " and "Initiating calibration "
Then this class will concatenate those sequences (pack them together)
"Hello, I am a Initiating calibration"
And make the necessary attention masks such that a sequence only attends to itself
during prefilling and generation.
This class creates a fixed-size cache of size max_tokens, or the sum of prompt sizes
plus the max number of generated tokens per sequence.
"""
self.model = model
self.tokenizer = tokenizer
self.temperature = cfg.temperature
self.top_p = cfg.top_p
self.top_k = cfg.top_k
self.max_gen_len = cfg.max_gen_len
self.max_tokens = cfg.max_tokens
self.max_prompt_len = cfg.max_prompt_len
self.until = cfg.until
self.max_until_size = max([len(e) for e in self.until]) if self.until else 1
self.device = cfg.device
# Compile if necessary
self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
self.generate_next_token = torch.compile(
self.generate_next_token,
mode="reduce-overhead",
disable=not cfg.reduce_generation_overhead,
)
self.show_progress = cfg.show_progress
self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]
self.prefill_doc_id, self.prefill_tok_id = None, None
self.padded_doc_id, self.padded_tok_id = None, None
self.current_doc_id, self.current_tok_id = None, None
self.padded_doc_start = None
self.prefill_mask = None
def clear_cache(self, offset):
for module in self.model.modules():
if isinstance(module, Attention):
if not hasattr(module, "kv_cache"):
module.kv_cache = KVCache(
1,
self.max_tokens,
module.n_kv_heads,
module.head_dim,
self.dtype,
self.device,
)
module.kv_cache.offset = offset
@torch.compiler.disable
def setup_prefilling(self, lengths: torch.Tensor):
# The KV cache is a fixed size tensor of size max_tokens that we need
# to update in order to do correct autoregressive generation.
# Here we will generate token by token but on multiple sequences
# at once. To do so, we need to have an attention mask that makes
# each sequence independent.
# Each sequence will write to its allocated space in the KV Cache.
# We allocate len(seq) + max_gen_len to each sequence in the cache.
# We will generate max_gen_len for each document
padded_lengths = lengths + self.max_gen_len
max_tokens = self.max_tokens or padded_lengths.sum().item()
# The last document might have more padding to fill up to max_tokens
padded_lengths[-1] += max_tokens - padded_lengths.sum()
# This is the start index in the cache for each document
self.padded_doc_start = lengths_to_start_ids(padded_lengths)
# For example with ab--123--cdef--
# this would be 0, 4, 9 if max_gen_len is 2
# We repeat interleave to align with tokens for prefilling
# Ex: ab--123--cdef--
# 000044444999999
prefill_offset = torch.repeat_interleave(self.padded_doc_start, lengths)
# This offset will make sure the tokens are written to the
# correct positions in the cache during prefilling
# We either init the cache or clear it by resetting the offset to prefill_offset
self.clear_cache(prefill_offset)
# The prefilling mask looks like the following for
# the two packed sequences ab and 123 : ab123
# Where spaces are empty cache positions
# keys
# ab---123---
# queries a 10000000000
# b 11000000000
# 1 00000100000
# 2 00000110000
# 3 00000111000
# We make sure to skip the empty cache positions
# and only attend to positions within the same sequence
doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths, padded_lengths)
self.prefill_mask = create_block_mask(
doc_mask_mod, 1, None, lengths.sum(), max_tokens
)
# This creates the prefilling token ids which look like
# the following for the packed sequence abcdefg1234
# abcdefg1234
# 01234560123
# The token id gives us the position within each sequence
# This is used to compute ROPE and to update the cache
# At each forward pass the current tokens are written to
# offset + tok_id
self.prefill_doc_id, self.prefill_tok_id = lengths_to_local_ids(lengths)
# This creates the padded token and document ids
# which look like the following for the packed sequence ab123
# ab---123--- ab---123---
# padded_doc_id 00000111111 padded_tok_id 01234012345
# This will later be useful for the attention mask at generation
self.padded_doc_id, self.padded_tok_id = lengths_to_local_ids(padded_lengths)
@torch.compiler.disable
def setup_generation(self, lengths):
# KV Cache offset is set to the start of the padded documents
for module in self.model.modules():
if isinstance(module, Attention):
module.kv_cache.offset = self.padded_doc_start
# The token ids during generations correspond to the lengths of each doc
# current_tok_id will be incremented during generation
self.current_tok_id = lengths.clone()
# Since we're generating one token per document
# the document id is just an arange
self.current_doc_id = torch.arange(lengths.size(0), device=lengths.device)
# From here on some methods for generation
def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
# Prefilling is done by taking multiple packed sequences and
# doing block diagonal attention on them so they remain independent
self.setup_prefilling(lengths=lengths)
prefill_out = self.model.forward(
tokens,
tok_idx=self.prefill_tok_id,
mask=self.prefill_mask,
attn_impl="flex_attention",
)
self.setup_generation(lengths=lengths)
return prefill_out
def generate_next_token(self, current_token):
# Since we're doing generation with multiple sequences at once
# we need to ignore tokens and cache entries from other sequences
# or in the future.
# Example mask :
# keys
# abc--1234--
# queries c 11100000000
# 4 00000111100
# mask shape : (n_seqs, cache_size)
doc_mask = self.current_doc_id.unsqueeze(1) == self.padded_doc_id.unsqueeze(0)
caus_mask = self.current_tok_id.unsqueeze(1) >= self.padded_tok_id.unsqueeze(0)
mask = doc_mask & caus_mask
out = self.model.forward(
current_token,
tok_idx=self.current_tok_id, # n_seqs
mask=mask,
attn_impl="sdpa",
)
self.current_tok_id += 1
return out
@torch.inference_mode()
def generate(self, prompts):
# Tokenize
prompts = [
self.tokenizer.encode(p, add_bos=True, add_eos=False) for p in prompts
]
# Truncate
max_seqlen = (
self.max_tokens
if not hasattr(self.model, "max_seqlen")
else self.model.max_seqlen
)
max_prompt_len = self.max_prompt_len or min(
max_seqlen - self.max_gen_len, self.max_tokens - self.max_gen_len
)
prompts = [p[-max_prompt_len:] for p in prompts]
# Account for the generation in lengths
padded_lengths = [len(p) + self.max_gen_len for p in prompts]
generation = []
loglikelihood = []
greedy = []
it = batch_prompts(prompts, self.max_tokens, lengths=padded_lengths)
if self.show_progress:
it = tqdm(it)
for batch in it:
n_seqs = len(batch)
generated_tokens = [[] for _ in range(n_seqs)]
is_done = [False for _ in range(n_seqs)]
packed_batch, lengths = pack_prompts(batch)
packed_batch, lengths = packed_batch.cuda(), lengths.cuda()
n_seqs = lengths.size(0)
# Prefilling cache
prompt_logits = self.prefill(packed_batch.unsqueeze(0), lengths)
# Selecting last token in each prompt
all_tokens = sample_tokens(
prompt_logits, self.temperature, self.top_p, self.top_k
)
start_token = all_tokens[:, lengths.cumsum(0) - 1]
for seq_id, tok in enumerate(start_token.squeeze(0).tolist()):
generated_tokens[seq_id].append(tok)
current_token = start_token
for i in range(1, self.max_gen_len):
next_logits = self.generate_next_token(current_token)
next_token = sample_tokens(
next_logits.clone(), self.temperature, self.top_p, self.top_k
)
for seq_id, tok in enumerate(next_token.squeeze(0).tolist()):
if not is_done[seq_id]:
generated_tokens[seq_id].append(tok)
current_end_str = self.tokenizer.decode(
generated_tokens[seq_id][-self.max_until_size :]
)
contains_end_string = any(
[e in current_end_str for e in self.until]
)
is_done[seq_id] = (
contains_end_string or tok == self.tokenizer.eos_id
)
if all(is_done):
break
current_token = next_token
generation.extend([self.tokenizer.decode(g) for g in generated_tokens])
for p, logit in zip(
batch, prompt_logits.squeeze(0).split(lengths.tolist())
):
x = logit[:-1]
y = torch.tensor(p[1:], device=x.device)
loglikelihood.append(-F.cross_entropy(x, y, reduction="none").cpu())
greedy.append((x.argmax(dim=-1) == y).cpu())
return generation, loglikelihood, greedy
def load_consolidated_model_and_tokenizer(
consolidated_path,
model_cls=LMTransformer,
model_args_cls=LMTransformerArgs,
):
ckpt_path = Path(consolidated_path)
config = ckpt_path / "params.json"
config = OmegaConf.load(config)
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
config.distributed.model_dtype
]
model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
model = model_cls(model_args)
st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
model.load_state_dict(st_dict["model"])
model = model.cuda().eval()
for param in model.parameters():
param.data = param.data.to(dtype=param_dtype)
return model, tokenizer, config
def main():
# Load CLI arguments (overrides) and combine with a YAML config
cfg = OmegaConf.from_cli()
gen_cfg = dataclass_from_dict(
PackedCausalTransformerGeneratorArgs, cfg, strict=False
)
print(cfg)
model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)
generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
# Allow multiple prompts
prompts = []
while True:
prompt = input("Enter a prompt (or press enter to finish): ")
if not prompt:
break
prompts.append(prompt)
# Start generation
start_time = time.time()
generation, loglikelihood, greedy = generator.generate(prompts)
end_time = time.time()
# Calculate tokens per second
total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
tokens_per_second = total_tokens / (end_time - start_time)
# Display the results
for i, gen in enumerate(generation):
print(f"\nPrompt {i+1}: {prompts[i]}")
print(f"Generated Text: {gen}")
print(f"\nTokens per second: {tokens_per_second:.2f}")
if __name__ == "__main__":
main()
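As a quick illustration of the prompt-packing helpers defined in `generate.py` above, this standalone sketch (with made-up token ids; the import assumes the repo root is on `PYTHONPATH`) shows how prompts are grouped and flattened before prefilling:
```python
from apps.main.generate import batch_prompts, pack_prompts

# Three toy "tokenized" prompts (ids are made up)
prompts = [[1, 5, 7], [1, 9], [1, 2, 3, 4]]

# Group prompts so each batch carries at most max_elements tokens
batches = batch_prompts(prompts, max_elements=6)
print(batches)  # [[[1, 5, 7], [1, 9]], [[1, 2, 3, 4]]]

# Flatten one batch into a single packed tensor plus per-prompt lengths
packed, lengths = pack_prompts(batches[0])
print(packed)   # tensor([1, 5, 7, 1, 9])
print(lengths)  # tensor([3, 2])
```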

654
apps/main/lingua_train.py Normal file

@@ -0,0 +1,654 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import gc
import logging
import os
import sys
from contextlib import ExitStack
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from pathlib import Path
from timeit import default_timer as timer
from typing import Any, Dict, Optional
import torch
import torch.distributed
import wandb
import xformers.profiler
from lingua.args import dataclass_from_dict, dump_config, flatten_dict
from lingua.data import (
DataArgs,
PackTokensState,
build_dataloader_from_args,
init_dataloader_state_from_args,
)
from lingua.tokenizers.build_tokenizer import TokenizerArgs
from omegaconf import OmegaConf
from pydantic import BaseModel
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import lr_scheduler
from bytelatent.checkpoint import (
CheckpointArgs,
CheckpointManager,
load_from_checkpoint,
)
from bytelatent.distributed import (
DistributedArgs,
EnvironmentArgs,
check_model_value_range,
clean_env,
dist_mean_dict,
get_device_mesh,
get_is_master,
get_world_size,
init_signal_handler,
parallelize_model,
requeue_slurm_job,
setup_env,
setup_torch_distributed,
)
from bytelatent.logger import init_logger
from bytelatent.metrics import (
GPUMemoryMonitor,
LoggingArgs,
MetricLogger,
get_num_params,
)
from bytelatent.optim import OptimArgs, build_optimizer
from bytelatent.probe import AutoProbeD
from bytelatent.profiling import ProfilerArgs, maybe_run_profiler
from bytelatent.stool import StoolArgs, launch_job
from bytelatent.transformer import (
LMTransformer,
LMTransformerArgs,
build_fsdp_grouping_plan,
get_no_recompute_ops,
get_num_flop_per_token,
tp_parallelize,
)
logger = logging.getLogger()
class TrainArgs(BaseModel):
name: str = "lingua"
dump_dir: str = ""
seed: int = 42
# Number of gradient accumulation steps
# Total batch size is batch_size*grad_acc_steps
grad_acc_steps: int = 1
gc_collect_freq: int = 1000
probe_freq: int | None = None
# Nb optimizer steps to take
steps: int = 1000
data: DataArgs
optim: OptimArgs
model: LMTransformerArgs
distributed: DistributedArgs
env: EnvironmentArgs
checkpoint: CheckpointArgs
profiling: ProfilerArgs
logging: LoggingArgs
# If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
async_eval_gpus: int | None = None
eval: Any | None = None
@dataclass
class TrainState(Stateful):
step: int # Nb of steps taken by the optimizer
acc_step: int # Nb of accumulation steps done since last optimizer step
scheduler: lr_scheduler.LambdaLR
data_loader_state: PackTokensState
def state_dict(self) -> Dict[str, Any]:
return {
"step": self.step,
"acc_step": self.acc_step,
"data_loader_state": self.data_loader_state,
"scheduler": self.scheduler.state_dict(),
}
def load_state_dict(self, state_dict):
self.step = state_dict["step"]
self.acc_step = state_dict["acc_step"]
self.data_loader_state = PackTokensState(**state_dict["data_loader_state"])
self.scheduler.load_state_dict(state_dict["scheduler"])
def validate_train_args(args: TrainArgs, output_size: int):
if args.model.vocab_size < 0:
logger.info(f"Setting model output size to {args.model.vocab_size}")
args.model.vocab_size = output_size
assert (
args.model.vocab_size == output_size
), "Vocab size should be the same as output size"
assert args.dump_dir, "Dump dir not set"
if args.checkpoint.path is None:
args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
for source in args.data.sources:
data_path = os.path.join(args.data.root_dir, source)
assert os.path.exists(data_path), f"{data_path} doesn't exist"
if (
args.distributed.dp_replicate
* args.distributed.dp_shard
* args.distributed.tp_size
!= get_world_size()
):
assert get_world_size() % args.distributed.dp_shard == 0
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
args.distributed.dp_replicate = (
args.distributed.dp_replicate // args.distributed.tp_size
)
logger.warning(
f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
)
assert (
args.distributed.dp_replicate
* args.distributed.dp_shard
* args.distributed.tp_size
== get_world_size()
)
if args.distributed.fsdp_type == "no_shard":
assert (
args.distributed.dp_shard == 1
and args.distributed.dp_replicate == get_world_size()
)
args.model.max_seqlen = args.data.seq_len
if args.distributed.tp_size == 1:
logger.warning(
"Tensor parallelism has not been tested for a while, use at your own risk"
)
assert (
args.probe_freq != args.profiling.mem_steps
), "Don't profile during probe step"
assert (
args.probe_freq != args.profiling.profile_steps
), "Don't profile during probe step"
if args.logging.wandb is not None:
args.logging.wandb.name = args.name
if args.probe_freq is not None:
assert (
args.distributed.tp_size == 1
), "Probing not supported with tensor parallelism"
assert (
args.distributed.selective_activation_checkpointing is False
), "Probing not supported with selective activation checkpointing"
preemption_flag = dict(flag=False)
def set_preemption_flag(signum, frame):
logger.warning("Signal handler called with signal " + str(signum))
logger.warning("Preemption ! checkpointing asap and exiting.")
preemption_flag["flag"] = True
def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
test = train_state.step % freq == 0
if acc_step is not None:
test = test and (train_state.acc_step == acc_step)
elif acc_freq is not None:
test = test and ((train_state.acc_step % acc_freq) == 0)
return test
def train(args: TrainArgs):
with ExitStack() as context_stack:
tokenizer_args = TokenizerArgs(
name=args.data.name,
init_kwargs=args.data.tokenizer.init_kwargs,
)
tokenizer = tokenizer_args.build()
validate_train_args(
args,
tokenizer.n_words,
)
if get_is_master():
os.makedirs(args.dump_dir, exist_ok=True)
dump_config(args, Path(args.dump_dir) / "config.yaml")
init_logger(Path(args.dump_dir) / "train.log")
init_signal_handler(set_preemption_flag) # For handling preemption signals.
setup_env(args.env)
setup_torch_distributed(args.distributed)
world_mesh = get_device_mesh(args.distributed)
logger.info(f"Starting job: {args.name}")
# build dataloader
# need dp world size and rank
dp_mesh = world_mesh["dp_replicate"]
dp_degree = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()
if args.distributed.dp_shard > 1:
dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
dp_degree *= world_mesh["dp_shard"].size()
logger.info(f"Running on dp rank : {dp_rank}")
logger.info(f"Running on dp size : {dp_degree}")
torch.manual_seed(args.seed)
logger.info("Building model")
# Initializing the model on the meta device allows us to initialize models much bigger than one GPU's memory
with torch.device("meta"):
model = LMTransformer(args.model)
logger.info("Model is built !")
model_param_count = get_num_params(model)
model = parallelize_model(
model,
world_mesh,
args.model,
args.distributed,
fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
tp_parallelize=tp_parallelize,
no_recompute_ops=get_no_recompute_ops(),
)
# Once we shard the model on different gpus we can actually initialize the model
# First we create empty tensors of the correct shapes
model = model.to_empty(device="cuda")
# Then we init the model. Please make sure this function initializes *ALL* parameters
# and buffers, otherwise you will have random values in the uninitialized tensors
# which will silently fail (give nan gradients for example)
if args.checkpoint.init_ckpt_path:
logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
load_from_checkpoint(
args.checkpoint.init_ckpt_path, model, model_key="model"
) # Put model_key="" if its directly the model checkpoint
model.rope_embeddings.reset_parameters()  # For RoPE initialization: since it's a buffer, it might not be loaded
else:
with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
torch.manual_seed(args.model.seed)
model.init_weights()
check_model_value_range(model, range=10.0, std=1.0)
# log model size
logger.info(f"Model size: {model_param_count:,} total parameters")
gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(
f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
)
logger.info(f"GPU memory usage: {gpu_memory_monitor}")
# build optimizer after applying parallelisms to the model
optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
data_loader_state = init_dataloader_state_from_args(
args.data, dp_rank, dp_degree
)
train_state = TrainState(
step=0,
acc_step=0,
data_loader_state=data_loader_state,
scheduler=scheduler,
)
checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
checkpoint.load(model, optimizer, train_state, world_mesh)
# Either load from latest checkpoint or start from scratch
if args.probe_freq is not None:
if get_is_master():
os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
torch.distributed.barrier()
probe = AutoProbeD(
model,
(
Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
if (dp_rank % 128 == 0)
else None
),
)
probe_mod = model._orig_mod if args.distributed.compile else model
gc.disable()
# train loop
model.train()
metric_logger = context_stack.enter_context(
MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
)
data_loader = context_stack.enter_context(
build_dataloader_from_args(
args.data,
state=train_state.data_loader_state,
)
)
torch_profiler = context_stack.enter_context(
maybe_run_profiler(args.dump_dir, model, args.profiling)
)
nwords_since_last_log = 0
time_last_log = timer()
gc.collect()
while train_state.step < args.steps:
# We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
train_state.acc_step += 1
train_state.acc_step = train_state.acc_step % args.grad_acc_steps
# get batch
curr_lr = float(optimizer.param_groups[0]["lr"])
data_load_start = timer()
batch, train_state.data_loader_state = next(data_loader)
batch = torch.tensor(
batch,
dtype=torch.long,
)
if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
logger.info("garbage collection")
# we do garbage collection manually otherwise different processes
# run the GC at different times so they slow down the whole pipeline
gc.collect()
input_ids = batch[:, :, 0].cuda()
labels = batch[:, :, 1].cuda()
data_load_time = round(timer() - data_load_start, 4)
nwords_since_last_log += input_ids.numel()
bsz, seqlen = labels.shape
# forward
start_timer = torch.cuda.Event(enable_timing=True)
end_timer = torch.cuda.Event(enable_timing=True)
start_timer.record()
# This is an automatic probe that will compute statistics
# of all linears' inputs, weights and outputs
# along with attention logits and entropy
# both in forward and backward pass
if (args.probe_freq is not None) and every_n_steps(
train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
):
# Here we do a fake forward and backward pass on a smaller
# batch size to avoid OOM
# This assumes the model has no stateful layers (batch norm..)
assert (
next(probe_mod.parameters()).grad is None
), "Can't probe model if grads are not reset"
with probe:
probe.metadata = {
"it": train_state.step,
"global_step": train_state.step,
"loop": "lingua",
}
# Non compiled model uses roughly 2x memory in our exps
# So we divide bsz by 2 or seqlen by 2
probe_bsz = max(1, bsz // 2)
probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
probe_loss = probe_mod(
input_ids[:probe_bsz, :probe_seq],
labels[:probe_bsz, :probe_seq],
)
probe_loss.backward()
# We zero grads to cancel this fake step
optimizer.zero_grad()
assert (
next(probe_mod.parameters()).grad is None
), "Probe model shouldn't have grads at this point"
loss = model(input_ids, labels)
# We scale loss with grad_acc_steps so the gradient is the same
# regardless of grad_acc_steps
loss = loss / args.grad_acc_steps
# backward on scaled loss to create scaled gradients
loss.backward()
# For logging we undo that scaling
loss = loss.detach() * args.grad_acc_steps
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=args.optim.clip, foreach=True
)
grad_norm = (
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
).item()
# optimizer step
if train_state.acc_step == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
train_state.step += 1
# updates the scale for next iteration
# training iteration complete
end_timer.record()
torch.cuda.synchronize()
curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)
# if profiler is active
if torch_profiler:
xformers.profiler.step()
# log metrics
if every_n_steps(
train_state,
args.logging.freq,
acc_step=None if args.logging.acc_freq else 0,
acc_freq=args.logging.acc_freq,
):
time_delta = timer() - time_last_log
wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)
gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
total_acc_steps = (
args.grad_acc_steps * train_state.step + train_state.acc_step
)
tokens_per_gpu = (
total_acc_steps * args.data.batch_size * args.data.seq_len
)
total_tokens = dp_degree * tokens_per_gpu
# This is an estimate and the correct values may change
# if you change the architecture
# Use xformer's analyze profile trace to get actual measurement
FLOPS = (
get_num_flop_per_token(
model_param_count - args.model.vocab_size * args.model.dim,
args.model.n_layers,
args.model.dim,
args.data.seq_len,
)
* wps
)
metrics = flatten_dict(
{
"global_step": train_state.step,
"acc_step": train_state.acc_step,
"speed": {
"wps": wps,
"FLOPS": FLOPS,
"curr_iter_time": curr_iter_time,
"data_load_time": data_load_time,
},
"optim": {
"grad_norm": grad_norm,
"lr": curr_lr,
"total_tokens": total_tokens,
},
"memory": gpu_mem_stats._asdict(),
},
sep="/",
)
to_sync = {}
to_sync["loss/out"] = loss.item()
metrics.update(dist_mean_dict(to_sync))
if get_is_master():
metric_logger.log(metrics)
gpu_memory_monitor.reset_peak_stats()
nwords_since_last_log = 0
time_last_log = timer()
logger.info(
f"step: {train_state.step}"
f" acc: {train_state.acc_step}"
f" loss: {round(loss.item(),4):>7}"
f" grad: {grad_norm:.2e}"
f" flops: {FLOPS:.2e}"
f" wps: {wps:.2e}"
f" iter: {curr_iter_time:>7}"
f" data: {data_load_time:>5}"
f" lr: {curr_lr:.2e}"
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
f" pow: {gpu_mem_stats.power_draw/1000} W"
)
saved = False
if every_n_steps(
train_state, args.checkpoint.dump.every, acc_step=0
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
saved = checkpoint.save(
model,
optimizer,
train_state,
args,
device_mesh=world_mesh,
)
if args.eval is not None and every_n_steps(
train_state, args.checkpoint.eval.every, acc_step=0
):
from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
eval_args = dataclass_from_dict(EvalArgs, args.eval)
eval_args.global_step = train_state.step
eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
eval_args.dump_dir = str(
os.path.join(
args.dump_dir,
"evals",
EVAL_FOLDER_NAME.format(train_state.step),
)
)
eval_args.metric_log_dir = args.dump_dir
if args.async_eval_gpus is None:
launch_eval(eval_args)
elif get_is_master():
if wandb.run is not None and args.logging.wandb is not None:
eval_args.wandb = deepcopy(args.logging.wandb)
assert args.async_eval_gpus > 0
logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
with clean_env():
launch_job(
StoolArgs(
asdict(eval_args),
script="apps.main.eval",
copy_code=False,
nodes=args.async_eval_gpus // 8,
qos="lowest",
)
)
if preemption_flag["flag"]:
if not saved:
checkpoint.save(
model,
optimizer,
train_state,
args,
device_mesh=world_mesh,
)
requeue_slurm_job()
sys.exit(0)
if not saved:
checkpoint.save(
model,
optimizer,
train_state,
args,
device_mesh=world_mesh,
)
gc.collect()
def main():
"""
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
This accepts arguments as a dot list
So if the dataclass looks like
@dataclass
class DummyArgs:
name: str
model: LMTransformerArgs
@dataclass
class LMTransformerArgs:
dim: int
Then you can pass model.dim=32 to change values in LMTransformerArgs
or just name=tictac for top level attributes.
The behavior here is as follows:
1. We instantiate TrainArgs with its default values
2. We override those default values with the ones in the provided config file
3. We override the result with the additional arguments provided through command line
For example, if the config is the following
model:
dim: 128
n_layers: 4
and you call train.py with train.py model.dim=64
Then the final TrainArgs will have
model:
dim: 64
n_layers: 4
Plus all the default values in TrainArgs dataclass.
"""
cli_args = OmegaConf.from_cli()
file_cfg = OmegaConf.load(cli_args.config)
# Remove the 'config' attribute from the config, as the underlying dataclass does not define it
del cli_args.config
default_cfg = OmegaConf.structured(TrainArgs())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
cfg = OmegaConf.to_object(cfg)
train(cfg)
if __name__ == "__main__":
main()
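A minimal sketch of the merge order described in the docstring above (dataclass defaults, then the YAML file, then dot-list CLI overrides), using the hypothetical DummyArgs from that docstring rather than the real TrainArgs:

from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class Model:
    dim: int = 128
    n_layers: int = 4

@dataclass
class DummyArgs:
    name: str = "lingua"
    model: Model = field(default_factory=Model)

default_cfg = OmegaConf.structured(DummyArgs())
file_cfg = OmegaConf.create({"model": {"dim": 256}})  # stands in for the YAML config file
cli_cfg = OmegaConf.from_dotlist(["model.dim=64", "name=tictac"])
merged = OmegaConf.to_object(OmegaConf.merge(default_cfg, file_cfg, cli_cfg))
# merged is DummyArgs(name='tictac', model=Model(dim=64, n_layers=4))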

BIN
blt-figure.jpg Normal file

Binary file not shown.


BIN
blt-figure.pdf Normal file

Binary file not shown.

BIN
bytelatent/.DS_Store vendored Normal file

Binary file not shown.

3
bytelatent/__init__.py Normal file
View file

@@ -0,0 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
class ByteLatentError(Exception):
pass

199
bytelatent/args.py Normal file
View file

@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
from typing import Any
import numpy as np
import yaml
from pydantic import BaseModel, ConfigDict
from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
find_and_sanitize_chunks,
)
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.iterators.sampling_iterator import SamplingIterator
from bytelatent.data.iterators.sequence_iterator import (
SequenceIterator,
SequencePackingArgs,
)
from bytelatent.data.patcher import PatcherArgs
from bytelatent.distributed import DistributedArgs, EnvironmentArgs
from bytelatent.metrics import LoggingArgs
from bytelatent.model.blt import ByteLatentTransformerArgs
from bytelatent.optim import OptimArgs
from bytelatent.profiling import ProfilerArgs
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
logger = logging.getLogger()
def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
return np.random.default_rng((seed, rank, world_size)).bit_generator.state
def distribute_data_to_rank(
*,
dataset_path: str,
preprocess_dir: str,
entropy_model_name: str | None,
arrow_batch_size: int,
rank: int,
world_size: int,
) -> ArrowFileIterator:
dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size)
n_workers_per_chunk = world_size // len(dataset_chunks)
rank_to_arrow_iterator_params = []
for chunk_path in dataset_chunks:
for worker_id in range(n_workers_per_chunk):
rank_to_arrow_iterator_params.append(
ArrowFileIterator(
file_path=chunk_path,
worker_id=worker_id,
num_workers=n_workers_per_chunk,
preprocess_dir=preprocess_dir,
dataset_files=None,
entropy_model_name=entropy_model_name,
arrow_batch_size=arrow_batch_size,
)
)
return rank_to_arrow_iterator_params[rank]
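# Worked example of the mapping above (illustrative, assuming 2 dataset chunks
# and world_size=4, so n_workers_per_chunk=2):
#   rank 0 -> (chunk 0, worker 0), rank 1 -> (chunk 0, worker 1),
#   rank 2 -> (chunk 1, worker 0), rank 3 -> (chunk 1, worker 1)
# Each worker then reads every num_workers-th row of its chunk (see ArrowFileIterator).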
class DataloaderArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
root_dir: str | None = None
sources: dict[str, float] = {}
batch_size: int = 2
seq_len: int = 2048
seed: int = 42
add_bos: bool = True
add_eos: bool = True
load_async: bool = True
prefetch_size: int = 64
preprocess_dir: str | None = None
dataset_files: list[str] | None = None
entropy_model_name: str | None = "transformer_100m"
arrow_batch_size: int = 100
buffer_size: int = 64
pad_to_max_length: bool = True
max_encoder_seq_length: int = 12288
enable_byte_ngrams: bool = False
tokenizer_args: TokenizerArgs = TokenizerArgs()
patcher_args: PatcherArgs = PatcherArgs()
def _create_sequence_iterators(
self, rank: int, world_size: int
) -> dict[str, SequenceIterator]:
sequence_packing_args = SequencePackingArgs(
output_seq_len=self.seq_len,
buffer_size=self.buffer_size,
)
source_to_sequence_iterator: dict[str, SequenceIterator] = {}
for dataset_path in self.sources:
shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
arrow_iterator = distribute_data_to_rank(
dataset_path=os.path.join(self.root_dir, dataset_path),
preprocess_dir=self.preprocess_dir,
entropy_model_name=self.entropy_model_name,
arrow_batch_size=self.arrow_batch_size,
rank=rank,
world_size=world_size,
)
looping_iterator = LoopingIterator(arrow_iterator)
preprocess_iterator = PreprocessIterator(
looping_iterator,
patcher_args=self.patcher_args,
tokenizer_args=self.tokenizer_args,
)
sequence_iterator = SequenceIterator(
preprocess_iterator,
sequence_packing_args=sequence_packing_args,
rng_state=shuffle_rng_state,
)
source_to_sequence_iterator[dataset_path] = sequence_iterator
return source_to_sequence_iterator
def build_from_rank(
self, rank: int, world_size: int
) -> StatefulIterator[Batch, Any]:
source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size)
weight_rng_state = get_rng_state(self.seed + 1, rank, world_size)
sampling_iterator = SamplingIterator(
rng_state=weight_rng_state,
source_to_weight=self.sources,
source_to_iterator=source_to_sequence_iterators,
)
tokenizer = self.tokenizer_args.build()
packing_args = PackingArgs(
batch_size=self.batch_size,
seq_len=self.seq_len,
pad_id=tokenizer.boe_id,
max_length=self.max_encoder_seq_length,
pad_to_max_length=self.pad_to_max_length,
enable_byte_ngrams=self.enable_byte_ngrams,
)
packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
mp_iterator = MultiprocessIterator(
packing_iterator, n_batches_to_prefetch=self.prefetch_size
)
return mp_iterator
class TrainArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
name: str = "lingua"
dump_dir: str = ""
seed: int = 42
# Number of gradient accumulation steps
# Total batch size is batch_size*grad_acc_steps
grad_acc_steps: int = 1
gc_collect_freq: int = 1000
probe_freq: int | None = None
# Nb optimizer steps to take
steps: int = 1000
data: DataloaderArgs = DataloaderArgs()
optim: OptimArgs = OptimArgs()
model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
distributed: DistributedArgs = DistributedArgs()
env: EnvironmentArgs = EnvironmentArgs()
checkpoint: CheckpointArgs = CheckpointArgs()
profiling: ProfilerArgs = ProfilerArgs()
logging: LoggingArgs = LoggingArgs()
# If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
async_eval_gpus: int | None = None
eval: Any | None = None
eval_on_gpus: int | None = None
def dump_to_yaml_file(
self, path: str, log_config: bool = True, sort_keys: bool = True
):
model_dict = self.model_dump(mode="json")
yaml_str = yaml.dump(
model_dict,
allow_unicode=True,
sort_keys=sort_keys,
default_flow_style=False,
)
with open(path, "w") as f:
if log_config:
logger.info("Using the following config for this run:")
logger.info(yaml_str)
f.write(yaml_str)

View file

@@ -0,0 +1,585 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from enum import Enum
from typing import Optional, Tuple, Union
import torch
from pydantic import BaseModel
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
BlockMask,
_mask_mod_signature,
flex_attention,
)
from xformers.ops import AttentionBias, fmha
from bytelatent import probe
flex_attention_comp = torch.compile(flex_attention)
class InitStdFactor(Enum):
DISABLED = "disabled" # Init std is divided by 1.0
GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers)
CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096
class BaseTransformerArgs(BaseModel):
dim: int = 512
n_layers: int = 8
head_dim: Optional[int] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
ffn_dim_multiplier: Optional[float] = None
multiple_of: int = 256
norm_eps: float = 1e-5
rope_theta: float = 10000.0
init_base_std: Optional[float] = None
init_std_factor: InitStdFactor = InitStdFactor.DISABLED
max_seqlen: int = 1024
def cross_entropy(pred, target, **kwargs):
return F.nll_loss(
F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
target.flatten(end_dim=-1),
**kwargs,
)
def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
end (int): End index for precomputing frequencies.
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
cos, sin = freqs.cos(), freqs.sin()
return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
seq_dim (int): Sequence dimension index.
Returns:
torch.Tensor: Reshaped frequency tensor.
"""
ndim = x.ndim
assert 0 <= seq_dim < ndim
assert freqs_cis.shape == (
x.shape[seq_dim],
x.shape[-3],
2,
2,
), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
shape = [
d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
] + [2, 2]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
seq_dim: int,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
freqs_cis = reshape_for_broadcast(
freqs_cis, xq_, seq_dim
).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
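# Illustrative shape check (not part of the original module), assuming B=2, S=16,
# H=8, D=64: precompute_freqs_cis returns an (S, D/2, 2, 2) tensor of rotation
# matrices and apply_rotary_emb preserves the (B, S, H, D) layout of xq/xk.
def _rope_shape_example() -> None:
    xq = torch.randn(2, 16, 8, 64)  # B S H D
    xk = torch.randn(2, 16, 8, 64)
    freqs_cis = precompute_freqs_cis(dim=64, end=16)
    assert freqs_cis.shape == (16, 32, 2, 2)
    out_q, out_k = apply_rotary_emb(xq, xk, seq_dim=1, freqs_cis=freqs_cis)
    assert out_q.shape == xq.shape and out_k.shape == xk.shape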
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def lengths_to_start_ids(lengths):
doc_start = lengths.cumsum(0)
doc_start = doc_start.roll(1)
doc_start[0] = 0
return doc_start
def lengths_to_local_ids(lengths):
assert lengths.ndim == 1
nb_seqs = lengths.size(0)
total_seqlen = lengths.sum()
# This gives the document id of each token
doc_id = torch.repeat_interleave(lengths)
# Compute document start for each document
doc_start = lengths_to_start_ids(lengths)
# Compute document start for each token
doc_start = doc_start[doc_id]
# Compute the position of each token within each document
tok_id = torch.arange(total_seqlen, device=lengths.device) - doc_start
return doc_id, tok_id
def generate_doc_mask_mod(
mask_mod: _mask_mod_signature,
lengths: torch.Tensor,
kv_lengths: Optional[torch.Tensor] = None,
) -> _mask_mod_signature:
"""Generates mask mods that apply to inputs to flex attention in the sequence stacked
format.
Args:
mask_mod: The mask mod to apply to the documents
lengths: Lengths of each document
Note:
What is the sequence stacked format? When assembling batches of inputs, we
take multiple sequences and stack them together to form 1 large sequence. We then
use masking to ensure that the attention scores are only applied to tokens within
the same document.
Example:
- Square mask
doc_mask lengths
a a b b b c c 2 3 2
a 1 0 0 0 0 0 0
a 1 1 0 0 0 0 0
b 0 0 1 0 0 0 0
b 0 0 1 1 0 0 0
b 0 0 1 1 1 0 0
c 0 0 0 0 0 1 0
c 0 0 0 0 0 1 1
"""
kv_lengths = kv_lengths if kv_lengths is not None else lengths
q_document_id, q_token_id = lengths_to_local_ids(lengths)
kv_document_id, kv_token_id = lengths_to_local_ids(kv_lengths)
q_max_idx = lengths.sum() - 1
kv_max_idx = kv_lengths.sum() - 1
def doc_mask_mod(b, h, q_idx, kv_idx):
q_idx_cap = torch.minimum(q_max_idx, q_idx)
kv_idx_cap = torch.minimum(kv_max_idx, kv_idx)
valid_idx = (q_idx <= q_max_idx) & (kv_idx <= kv_max_idx)
same_doc = q_document_id[q_idx_cap] == kv_document_id[kv_idx_cap]
q_logical = q_token_id[q_idx_cap]
kv_logical = kv_token_id[kv_idx_cap]
inner_mask = mask_mod(b, h, q_logical, kv_logical)
return same_doc & inner_mask & valid_idx
return doc_mask_mod
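# Illustrative check (not part of the original module): reproduce the doc-mask
# pattern from the docstring above for documents of lengths [2, 3, 2]. The mask
# mod is evaluated densely here; in training it would typically be converted
# into a BlockMask for flex_attention instead.
def _doc_mask_example() -> torch.Tensor:
    lengths = torch.tensor([2, 3, 2])
    doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths)
    idx = torch.arange(int(lengths.sum()))
    # result[q, kv] is True only for causal positions within the same document
    return doc_mask_mod(None, None, idx[:, None], idx[None, :])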
# Rotary embedding as in xformers; see if the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
class RotaryEmbedding(torch.nn.Module):
"""
RotaryEmbedding Module
"""
def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
super().__init__()
self.theta = theta
self.head_dim = head_dim
self.max_seqlen = max_seqlen
self.register_buffer(
"freqs_cis",
precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
persistent=False,
)
def reset_parameters(self):
self.freqs_cis[...] = precompute_freqs_cis(
dim=self.head_dim, end=self.max_seqlen, theta=self.theta
)
def forward(
self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
):
"""
Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions
Args:
seqlen (int): Contiguous sequence length
tok_idx (torch.Tensor[int]): Position indices of each token; overrides seqlen if provided
Returns:
torch.Tensor: freqs_cis values for the requested positions
"""
test = (seqlen is not None) or (tok_idx is not None)
assert test, "Should provide at least seqlen or tok_idx"
if tok_idx is not None:
return self.freqs_cis[tok_idx]
elif seqlen is not None:
return self.freqs_cis[0:seqlen]
class RMSNorm(nn.Module):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x: torch.Tensor):
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor):
x = probe.log_stats(x, "resid")
output = self._norm(x.float())
return (output * self.weight.float()).type_as(x)
def reset_parameters(self):
torch.nn.init.ones_(self.weight) # type: ignore
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int,
n_heads: int,
n_kv_heads: int,
rope_theta: float,
):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.rope_theta = rope_theta
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.heads_per_group = self.n_heads // self.n_kv_heads
self.wq = nn.Linear(
dim,
n_heads * head_dim,
bias=False,
)
self.wk = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
)
self.wv = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
)
self.wo = nn.Linear(
n_heads * head_dim,
dim,
bias=False,
)
def forward(
self,
x: torch.Tensor,
freq_cis: torch.Tensor,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
attn_impl: str = "sdpa",
) -> torch.Tensor:
# B S D
bsz, seq_len, dim = x.shape
xq = self.wq(x.view_as(x))
xk = self.wk(x.view_as(x))
xv = self.wv(x.view_as(x))
output_shape = xq.shape
# B S D -> B S H D
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
# This condition helps us be easily compatible
# with inference by adding a pluggable KVCache
if hasattr(self, "kv_cache"):
xk, xv = self.kv_cache.update(xk, xv, tok_idx)
xk = repeat_kv(xk, self.heads_per_group, dim=2)
xv = repeat_kv(xv, self.heads_per_group, dim=2)
if attn_impl == "flex_attention":
assert mask is None or isinstance(mask, BlockMask)
xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
elif attn_impl == "fmha":
assert mask is None or isinstance(mask, AttentionBias)
output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
# This uses B S H D instead of B H S D of pytorch
elif attn_impl == "sdpa":
xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
assert mask is None or isinstance(mask, (str, torch.Tensor))
is_causal = (mask == "causal") if isinstance(mask, str) else False
mask = mask if isinstance(mask, torch.Tensor) else None
output = F.scaled_dot_product_attention(
xq,
xk,
xv,
is_causal=is_causal,
attn_mask=mask,
)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
else:
raise NotImplementedError(
f"Attention implementation {attn_impl} not supported"
)
output = self.wo(output.reshape(output_shape))
return output
def reset_parameters(self, init_std=None, factor=1.0):
init_std = init_std or (self.dim ** (-0.5))
for w in [self.wq, self.wk, self.wv]:
nn.init.trunc_normal_(
w.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
nn.init.trunc_normal_(
self.wo.weight,
mean=0.0,
std=init_std / factor,
a=-3 * init_std,
b=3 * init_std,
)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
mp_size: int = 1,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
assert hidden_dim % mp_size == 0
self.dim = dim
self.hidden_dim = hidden_dim
self.w1 = nn.Linear(
dim,
hidden_dim,
bias=False,
)
self.w3 = nn.Linear(
dim,
hidden_dim,
bias=False,
)
self.w2 = nn.Linear(
hidden_dim,
dim,
bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# B S D
x1 = self.w1(x.view_as(x))
x3 = self.w3(x.view_as(x))
output = self.w2(F.silu(x1) * x3)
return output
def reset_parameters(self, init_std=None, factor=1.0):
in_init_std = init_std or (self.dim ** (-0.5))
out_init_std = init_std or (self.hidden_dim ** (-0.5))
in_init_std = in_init_std
out_init_std = out_init_std / factor
for w in [self.w1, self.w3]:
nn.init.trunc_normal_(
w.weight,
mean=0.0,
std=in_init_std,
a=-3 * in_init_std,
b=3 * in_init_std,
)
nn.init.trunc_normal_(
self.w2.weight,
mean=0.0,
std=out_init_std,
a=-3 * out_init_std,
b=3 * out_init_std,
)
class TransformerBlock(nn.Module):
def __init__(self, args: BaseTransformerArgs):
super().__init__()
assert (args.head_dim is not None) or (
args.n_heads is not None
), "Should specify at least head_dim or n_heads"
self.head_dim = args.head_dim or args.dim // args.n_heads
self.n_heads = args.n_heads or args.dim // args.head_dim
self.n_kv_heads = args.n_kv_heads or self.n_heads
assert args.n_heads % self.n_kv_heads == 0
assert args.dim % args.n_heads == 0
self.attention = Attention(
dim=args.dim,
head_dim=self.head_dim,
n_heads=self.n_heads,
n_kv_heads=self.n_kv_heads,
rope_theta=args.rope_theta,
)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
x: torch.Tensor,
freq_cis: torch.Tensor,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
attn_impl: str = "sdpa",
) -> torch.Tensor:
h = x + self.attention(
self.attention_norm(x),
freq_cis,
tok_idx=tok_idx,
mask=mask,
attn_impl=attn_impl,
)
out = h + self.feed_forward(self.ffn_norm(h))
return out
def init_weights(self, init_std=None, factor=1.0):
self.attention.reset_parameters(init_std, factor)
self.attention_norm.reset_parameters()
self.feed_forward.reset_parameters(init_std, factor)
self.ffn_norm.reset_parameters()
class BaseTransformer(nn.Module):
def __init__(self, args: BaseTransformerArgs):
super().__init__()
self.dim = args.dim
self.init_base_std = args.init_base_std
self.init_std_factor = InitStdFactor(args.init_std_factor)
self.max_seqlen = args.max_seqlen
self.rope_embeddings = RotaryEmbedding(
theta=args.rope_theta,
head_dim=args.head_dim or args.dim // args.n_heads,
max_seqlen=args.max_seqlen,
)
self.layers = nn.ModuleList()
for _ in range(args.n_layers):
self.layers.append(TransformerBlock(args))
def forward(
self,
h,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
attn_impl: str = "sdpa",
):
freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
for i, layer in enumerate(self.layers):
h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
return h
def reset_parameters(self):
# Either use fixed base std or sqrt model dim
self.rope_embeddings.reset_parameters()
def init_weights(self):
self.reset_parameters()
for depth, layer in enumerate(self.layers):
factor = {
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
InitStdFactor.DIM_RATIO: self.dim / 4096,
InitStdFactor.DISABLED: 1.0,
}[self.init_std_factor]
layer.init_weights(self.init_base_std, factor)

311
bytelatent/checkpoint.py Normal file
View file

@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.optim.optimizer
from pydantic import BaseModel, ConfigDict
from torch.distributed._tensor import DeviceMesh
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_state_dict,
set_state_dict,
)
from bytelatent.distributed import get_is_master
logger = logging.getLogger("CHECKPOINT")
FOLDER_NAME = "{:010d}"
RE_FOLDER = r"\d{10}"
RE_CKPT = r"__\d_\d\.distcp"
CONSOLIDATE_FOLDER = "consolidated"
CONSOLIDATE_NAME = "consolidated.pth"
CONFIG_NAME = "params.json"
TRAIN_STATE_NAME = "train_state_{:05d}.json"
RE_DIGITS = re.compile(r"\d+")
class SaveEvery(BaseModel):
model_config = ConfigDict(extra="forbid")
every: int = 1000
keep: int = 0
class CheckpointArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
dump: SaveEvery = SaveEvery()
eval: SaveEvery = SaveEvery()
path: str | None = None
init_ckpt_path: str | None = None
continue_training_from_init: bool = False
def _get_key_step(name: str):
return int(re.findall(RE_DIGITS, name)[-1])
def consolidate_checkpoints(ckpt_dir: str):
"""
Consolidates all FSDP checkpoints in a directory to a single file
Consolidate checkpoint is saved in a subdirectory of ckpt_dir
Parameters:
ckpt_dir: str - path to the directory containing the checkpoints
Returns the path to the consolidated checkpoint
"""
consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
if not (consolidate_path / CONSOLIDATE_NAME).exists():
consolidate_path.mkdir(exist_ok=True)
logger.info(f"Consolidating to: {str(consolidate_path)}")
dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
(consolidate_path / CONFIG_NAME).write_text(
(Path(ckpt_dir) / CONFIG_NAME).read_text()
)
logger.info("Consolidated !")
return consolidate_path
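# Illustrative usage (hypothetical path): merge the sharded .distcp files of one
# step folder into a single consolidated.pth next to a copy of params.json, e.g.
#   consolidate_checkpoints("dump_dir/checkpoints/0000001000")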
def load_from_checkpoint(
ckpt_dir: str,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
model_key: str = "model",
optim_key: str = "optim",
):
if not (Path(ckpt_dir) / ".metadata").exists():
raise ValueError(
f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
)
state_dict = {}
if optimizer is not None:
state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer)
else:
state_dict[model_key] = get_model_state_dict(model)
if model_key == "": # If only loading a model directly, the key should be empty
state_dict = state_dict.pop(model_key)
dcp.load(state_dict, checkpoint_id=ckpt_dir)
class CheckpointManager:
def __init__(self, args: CheckpointArgs):
self.path = args.path
self.dump_every = args.dump
self.eval_every = args.eval
self.init_ckpt_path = args.init_ckpt_path
self.continue_training_from_init = args.continue_training_from_init
assert os.path.exists(
self.path
), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
self.existing_saves = self.get_existing_saves()
def get_existing_saves(self) -> List[Path]:
folders = [
p
for p in Path(self.path).iterdir()
if p.is_dir() and re.match(RE_FOLDER, p.name)
]
folders.sort(key=lambda p: _get_key_step(p.name))
return folders
def clean_up(self):
logger.info("Cleaning up checkpoints...")
dump_folders = []
eval_folders = []
other_folders = []
for p in self.existing_saves:
is_dump = _get_key_step(p.name) % self.dump_every.every == 0
is_eval = _get_key_step(p.name) % self.eval_every.every == 0
if is_dump:
dump_folders.append(p)
if is_eval:
eval_folders.append(p)
if not (is_dump or is_eval):
other_folders.append(p)
logger.info(f"Dump folders: {dump_folders}")
logger.info(f"Eval folders: {eval_folders}")
logger.info(f"Other folders: {other_folders}")
if self.dump_every.keep > 0:
dump_folders = dump_folders[-self.dump_every.keep :]
if self.eval_every.keep > 0:
eval_folders = eval_folders[-self.eval_every.keep :]
folder_to_keep = set(other_folders + dump_folders + eval_folders)
folder_to_remove = set(self.existing_saves) - folder_to_keep
logger.info(f"Removing folders: {folder_to_remove}")
if dist.get_rank() == 0:
for folder in folder_to_remove:
for file in folder.iterdir():
if file.is_file():
file.unlink()
elif file.is_dir():
assert file.name in [CONSOLIDATE_FOLDER]
for f in file.iterdir():
f.unlink()
file.rmdir()
folder.rmdir()
dist.barrier()
self.existing_saves = list(folder_to_keep)
self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
path = None
for p in reversed(self.existing_saves):
if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
path = p
break
return path
def _create_folder(self, base_path: Path, folder_name: str) -> Path:
folder = base_path / folder_name
if get_is_master():
folder.mkdir(parents=False, exist_ok=True)
if dist.is_initialized():
dist.barrier()
return folder
def _get_dp_tp_mesh(
self, device_mesh: Optional[DeviceMesh] = None
) -> Tuple[int, int]:
dp_rank = 0
tp_rank = 0
if device_mesh is not None:
if "dp_replicate" in device_mesh.mesh_dim_names:
dp_rank = device_mesh.get_local_rank("dp_replicate")
if "dp_shard" in device_mesh.mesh_dim_names:
dp_rank = dp_rank * device_mesh[
"dp_replicate"
].size() + device_mesh.get_local_rank("dp_shard")
if "tp" in device_mesh.mesh_dim_names:
tp_rank = device_mesh.get_local_rank("tp")
return dp_rank, tp_rank
@torch.no_grad()
def get_state_dict(
self,
model,
optimizer,
):
model_sd, optim_sd = get_state_dict(model, optimizer)
return {"model": model_sd, "optim": optim_sd}
def save(
self,
model,
optimizer,
train_state,
config,
device_mesh: Optional[DeviceMesh] = None,
) -> bool:
# When creating the directory, check whether only rank 0 should do it or whether there is a better solution
path = Path(self.path)
curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
logger.info(f"Saving to: {str(curr_save_dir)}")
if dist.is_initialized():
dist.barrier()
logger.info("Saving...")
state_dict = self.get_state_dict(model, optimizer)
dcp.save(state_dict, checkpoint_id=curr_save_dir)
logger.info("State dict saved!")
if dist.is_initialized():
dist.barrier()
if get_is_master():
config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
# Add json dump here
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
if tp_rank == 0:
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info(
f"Saving train state to: {str(curr_save_dir / train_state_name)}"
)
with open(curr_save_dir / train_state_name, "w") as f:
json.dump(train_state.state_dict(), f)
logger.info("Train state saved !")
self.existing_saves.append(curr_save_dir)
self.clean_up()
if dist.is_initialized():
dist.barrier()
return True
@torch.no_grad()
def load(
self,
model: nn.Module,
optimizer,
train_state,
device_mesh: DeviceMesh,
path: Optional[Path] = None,
):
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
# Loading tries the provided path first, then the last saved step, and finally the init path
path = path or self.get_last_step_path(dp_rank=dp_rank)
# If none of those are available don't do anything
if path is None:
# If no checkpoints exist do nothing
return
# Only load train state if it's provided, the files exist and we're not loading from init path
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info("Reloading train state")
with open(path / train_state_name, "r") as f:
train_state_dict = json.load(f)
train_state.load_state_dict(train_state_dict)
logger.info("Train state reloaded")
logger.info(f"Loading from: {str(path)}")
state_dict = self.get_state_dict(
model=model,
optimizer=optimizer,
)
dcp.load(state_dict, checkpoint_id=path)
logger.info("State dict loaded.")
logger.info("Reloading model and optim")
set_state_dict(
model,
optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"],
)
logger.info("Model and optim reloaded")
@classmethod
def instantiate_and_make_dir(cls, args: CheckpointArgs):
if get_is_master():
os.makedirs(args.path, exist_ok=True)
dist.barrier()
return cls(args)

View file

@@ -0,0 +1,110 @@
# Template config, need to change dump_dir, data.root_dir and tokenizer.path
# Evals can be activated by uncommenting its config
# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
dump_dir: /tmp/
name: "debug"
steps: 100_000
probe_freq: null
seed: 777
optim:
lr: 4e-04
warmup: 500
lr_min_ratio: 0.1
clip: 10.0
distributed:
fsdp_type: full_shard
compile: true
model_dtype: bf16
matmul_allow_tf32: false
selective_activation_checkpointing: false
tp_size: 1
model:
n_heads: 8
dim: 512
vocab_size: 260
dim_token: 256
patch_size: 6
tokenization_mode: "bytes"
patching_mode: "space"
tie_local_encoder_decoder_logits: false
data_loader_patching: true
max_encoder_seq_length: 12288
pad_to_max_length: true
patching_threshold: 3.1439168453216553
encoder_hash_byte_group_size: [4]
encoder_hash_byte_group_vocab: 50002
encoder_hash_byte_group_nb_functions: 3
encoder_enable_byte_ngrams: false
cross_attn_encoder: true # assuming cross_attention is true
cross_attn_decoder: true # assuming cross_attention is true
cross_attn_window_encoder: 512
cross_attn_window_decoder: 512
dim_local_encoder: 256
dim_local_decoder: 256
cross_attn_k: 8
cross_attn_nheads: 4
cross_attn_all_layers_decoder: true
cross_attn_all_layers_encoder: true
cross_attn_use_flex_attention: true
cross_attn_init_by_pooling: true
log_patch_lengths: true
non_linearity: "swiglu"
use_rope: true
recompute_fc1_out: false
recompute_fc3_out: false
recompute_attn: false
custom_bwd: false
layer_ckpt: "none"
efficient_attn: "sdpa"
patch_only_encoder: false
patch_only_decoder: false
use_local_encoder_transformer: true
init_use_gaussian: true
init_use_depth: "current"
attn_bias_type: "block_causal"
alpha_depth: "disabled"
max_length: 256
local_attention_window_len: 512
max_seqlen: 12288
downsampling_by_pooling: "max"
data:
root_dir: ???
sources:
dclm_baseline_1.0: 1.0
batch_size: 2
prefetch_size: 64
seq_len: 4096
load_async: true
preprocess_dir: ???
tokenizer_args:
name: blt
init_kwargs:
bpe_tokenizer_path: ???
profiling:
run: false
checkpoint:
dump:
every: 500
keep: 3
eval:
every: 1000
keep: -1
logging:
freq: 10
eval_on_gpus: 8
eval:
dataset_dir: /checkpoint/amaia/codegen/datasets/eval
tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
generator:
max_tokens: 65536
dtype: bf16
mp_size: 1

5
bytelatent/constants.py Normal file
View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
from pathlib import Path
BLT_DATA = Path(os.environ.get("BLT_DATA", "data"))

View file

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

View file

@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
from dataclasses import dataclass
from typing import Any, Iterator
import numpy as np
from pydantic import BaseModel, ConfigDict
class BltExample(BaseModel):
model_config = ConfigDict(extra="forbid")
sample_id: str
text: str
tokens: list[int] | None
entropies: list[float] | None
patch_lengths: list[int] | None
mask: list[bool] | None
class MultiChoiceState(BaseModel):
model_config = ConfigDict(extra="forbid")
root_dir: str
sources: dict[str, float]
source_to_state: dict[str, Any]
rng_state: dict[str, Any]
class PrefetchState(BaseModel):
model_config = ConfigDict(extra="forbid")
seq_idx: int
rng_state: dict[str, Any]
prefetch_size: int
batch_size: int
class BltPackTokensState(BaseModel):
model_config = ConfigDict(extra="forbid")
start_token: int
output_seq_len: int
n_views: int = 2
class DataLoaderState(BaseModel):
model_config = ConfigDict(extra="forbid")
multi_choice_state: MultiChoiceState
pack_tokens_state: BltPackTokensState
prefetch_state: PrefetchState
BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
class BltSequence(BaseModel):
tokens: list[int]
mask: list[bool]
patch_lengths: list[int]
@dataclass
class Batch:
x: np.ndarray
y: np.ndarray
mask: np.ndarray | None = None
patch_lengths: np.ndarray | None = None
ngram_ids: np.ndarray | None = None
is_final: bool = False
def to_python_dict(self) -> dict:
x = self.x.tolist()
y = self.y.tolist()
if self.mask is None:
mask = None
else:
mask = self.mask.tolist()
if self.patch_lengths is None:
patch_lengths = None
else:
patch_lengths = self.patch_lengths.tolist()
if self.ngram_ids is None:
ngram_ids = None
else:
ngram_ids = self.ngram_ids.tolist()
return {
"x": x,
"y": y,
"mask": mask,
"patch_lengths": patch_lengths,
"ngram_ids": ngram_ids,
"is_final": self.is_final,
}
@classmethod
def from_python_dict(cls, data: dict) -> "Batch":
x = np.array(data["x"])
y = np.array(data["y"])
if data["mask"] is None:
mask = None
else:
mask = np.array(data["mask"])
if data["patch_lengths"] is None:
patch_lengths = None
else:
patch_lengths = np.array(data["patch_lengths"])
if data["ngram_ids"] is None:
ngram_ids = None
else:
ngram_ids = np.array(data["ngram_ids"])
return Batch(
x=x,
y=y,
mask=mask,
patch_lengths=patch_lengths,
ngram_ids=ngram_ids,
is_final=data["is_final"],
)
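# Minimal round-trip sketch (not part of the original module): a Batch converts
# to plain Python types (e.g. for JSON transport) and is rebuilt with
# from_python_dict.
def _batch_roundtrip_example() -> Batch:
    batch = Batch(x=np.array([[1, 2, 3]]), y=np.array([[2, 3, 4]]))
    restored = Batch.from_python_dict(batch.to_python_dict())
    assert np.array_equal(batch.x, restored.x) and restored.mask is None
    return restored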

View file

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

View file

@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import abc
from typing import Any, Generator, Generic, TypeVar
T = TypeVar("T")
C = TypeVar("C")
class StatefulIterator(Generic[T, C], abc.ABC):
@abc.abstractmethod
def get_state(self) -> C:
pass
@abc.abstractmethod
def create_iter(self) -> Generator[T, Any, None]:
pass
class IteratorState(Generic[C]):
@abc.abstractmethod
def build(self) -> StatefulIterator[T, C]:
pass
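# Minimal illustrative pair (not part of the original module; the names are made
# up) showing the contract: get_state() captures enough to rebuild the iterator,
# and build() resumes exactly where the previous iterator stopped.
class CountingState(IteratorState):
    def __init__(self, current: int):
        self.current = current

    def build(self) -> "CountingIterator":
        return CountingIterator(self.current)


class CountingIterator(StatefulIterator):
    def __init__(self, current: int = 0):
        self.current = current

    def get_state(self) -> CountingState:
        return CountingState(self.current)

    def create_iter(self) -> Generator[int, Any, None]:
        while True:
            yield self.current
            self.current += 1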

View file

@@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import re
from logging import getLogger
from pathlib import Path
from typing import Any, Generator
import pyarrow as pa
# pyarrow needs the initialization from this import
import pyarrow.dataset # pyright: ignore
from pydantic import BaseModel, ConfigDict
from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
logger = getLogger(__name__)
class ArrowFileIteratorState(BaseModel, IteratorState):
model_config = ConfigDict(extra="forbid")
file_path: str | None
row_num: int
num_workers: int
worker_id: int
preprocess_dir: str | None
dataset_files: list[str] | None
entropy_model_name: str | None
arrow_batch_size: int = 100
def build(self) -> "ArrowFileIterator":
arrow_file = ArrowFileIterator(
file_path=self.file_path,
worker_id=self.worker_id,
num_workers=self.num_workers,
preprocess_dir=self.preprocess_dir,
entropy_model_name=self.entropy_model_name,
arrow_batch_size=self.arrow_batch_size,
dataset_files=self.dataset_files,
)
if self.row_num != 0:
arrow_file._set_row_num(self.row_num)
return arrow_file
def shard_sort_key(file: str | Path):
match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
shard_number = int(match.group(1))
return shard_number
class ArrowFileIterator(StatefulIterator):
def __init__(
self,
*,
file_path: str | None,
worker_id: int,
num_workers: int,
preprocess_dir: str | None,
entropy_model_name: str | None,
arrow_batch_size: int,
dataset_files: list[str] | None = None,
):
assert 0 <= worker_id < num_workers, (worker_id, num_workers)
if file_path is None and dataset_files is None:
raise ByteLatentError("file_path and dataset_files cannot both be None")
self.row_num = 0
self.iter_id = 0
self.batch_iterator = None
self.batch_to_consume = None
self.dataset = None
self.file_path = file_path
self.worker_id = worker_id
self.num_workers = num_workers
self.preprocess_dir = preprocess_dir
self.entropy_model_name = entropy_model_name
self.arrow_batch_size = arrow_batch_size
if dataset_files is None:
# Prepare arrow shards
jsonl_file = Path(file_path)
parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
assert parts is not None
dataset = parts.group(1)
data_dir = Path(preprocess_dir) / dataset / entropy_model_name
shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
for s in shard_files:
if not (data_dir / f"{s.name}.complete").exists():
raise ValueError(f"Missing .complete for input file: {s}")
shard_files = sorted(shard_files, key=shard_sort_key)
if len(shard_files) == 0:
raise ByteLatentError(
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
)
self.dataset_files = [str(f) for f in shard_files]
else:
self.preprocess_dir = None
self.dataset_files = dataset_files
def get_state(self) -> ArrowFileIteratorState:
return ArrowFileIteratorState(
file_path=self.file_path,
row_num=self.row_num,
worker_id=self.worker_id,
num_workers=self.num_workers,
preprocess_dir=self.preprocess_dir,
entropy_model_name=self.entropy_model_name,
arrow_batch_size=self.arrow_batch_size,
dataset_files=self.dataset_files,
)
def create_iter(
self,
) -> Generator[BltExample, Any, None]:
if self.dataset is None:
self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
self.batch_iterator = self.dataset.to_batches(
batch_size=self.arrow_batch_size
)
self.iter_id += 1
if self.batch_to_consume is not None:
batch_columns: dict[str, list] = self.batch_to_consume
self.batch_to_consume = None
sample_ids = batch_columns["sample_id"]
texts = batch_columns["text"]
entropies = batch_columns["entropies"]
for i in range(len(sample_ids)):
out = BltExample(
sample_id=sample_ids[i],
entropies=entropies[i],
text=texts[i],
tokens=None,
mask=None,
patch_lengths=None,
)
self.row_num += 1
if (self.row_num - 1) % self.num_workers == self.worker_id:
yield out
for batch in self.batch_iterator:
batch_columns = batch.to_pydict()
sample_ids = batch_columns["sample_id"]
texts = batch_columns["text"]
entropies = batch_columns["entropies"]
for i in range(len(sample_ids)):
out = BltExample(
sample_id=sample_ids[i],
entropies=entropies[i],
text=texts[i],
tokens=None,
mask=None,
patch_lengths=None,
)
self.row_num += 1
if (self.row_num - 1) % self.num_workers == self.worker_id:
yield out
def _set_row_num(self, target_row_num: int):
logger.info(
f"Setting arrow position to {target_row_num} for {self.dataset_files}"
)
if target_row_num is None or target_row_num == 0:
self.row_num = 0
self.dataset = None
self.batch_iterator = None
self.batch_to_consume = None
else:
self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
self.batch_iterator = self.dataset.to_batches(
batch_size=self.arrow_batch_size
)
curr_remaining = target_row_num
for batch in self.batch_iterator:
if len(batch) > curr_remaining:
batch_columns: dict[str, list] = batch.to_pydict()
batch_columns["sample_id"] = batch_columns["sample_id"][
curr_remaining:
]
batch_columns["entropies"] = batch_columns["entropies"][
curr_remaining:
]
batch_columns["text"] = batch_columns["text"][curr_remaining:]
self.batch_to_consume = batch_columns
break
elif len(batch) == curr_remaining:
# We are exactly at the end of the batch,
# so the next batch is the right spot
break
else:
curr_remaining -= len(batch)
self.row_num = target_row_num
logger.info(
f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
)
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
def find_and_sanitize_chunks(
dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
):
dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
n_chunks = len(dataset_chunks)
if n_chunks > world_size:
n_discard = n_chunks - world_size
dataset_chunks = dataset_chunks[:world_size]
else:
assert (
world_size % n_chunks == 0
), "World size should be a multiple of number of chunks"
assert n_chunks > 0, f"No valid chunks in {dataset_path}"
return dataset_chunks

View file

@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from pydantic import BaseModel
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
ArrowFileIteratorState,
)
class LoopingIteratorState(BaseModel, IteratorState):
file_iterator_state: ArrowFileIteratorState
epoch: int
def build(self) -> "LoopingIterator":
return LoopingIterator(
file_iterator=self.file_iterator_state.build(),
epoch=self.epoch,
)
class LoopingIterator(StatefulIterator):
def __init__(self, file_iterator: ArrowFileIterator, epoch: int = -1):
self.file_iterator = file_iterator
self.epoch = epoch
def get_state(self):
return LoopingIteratorState(
file_iterator_state=self.file_iterator.get_state(), epoch=self.epoch
)
def create_iter(self):
while True:
self.epoch += 1
iterator = self.file_iterator.create_iter()
yield from iterator

View file

@@ -0,0 +1,243 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import multiprocessing as mp
from multiprocessing.synchronize import Event as EventClass
from queue import Empty, Full
import numpy as np
from pydantic import BaseModel, ConfigDict
from bytelatent.data.data_types import Batch
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.packing_iterator import PackingIteratorState
logger = logging.getLogger()
class MultiprocessIteratorState(BaseModel, IteratorState):
model_config = ConfigDict(extra="forbid")
base_iterator_state: PackingIteratorState
n_batches_to_prefetch: int
serialized_prefetch_buffer: str
def build(self):
base_iterator = self.base_iterator_state.build()
data = json.loads(self.serialized_prefetch_buffer)
prefetch_buffer = [Batch.from_python_dict(item) for item in data]
return MultiprocessIterator(
base_iterator,
n_batches_to_prefetch=self.n_batches_to_prefetch,
prefetch_buffer=prefetch_buffer,
)
def start_work_from_state(
batch_queue: mp.Queue,
state_queue: mp.Queue,
stop_event: EventClass,
state_dumped_event: EventClass,
state: IteratorState,
):
logging.info("Worker thread: Starting base_iterator work")
stateful_iterator = state.build()
iterator = stateful_iterator.create_iter()
for item in iterator:
while not stop_event.is_set():
try:
# Attempt to put on queue or timeout to try again (maybe main thread is busy)
batch_queue.put(item, timeout=0.1)
# On success, stop trying
break
except Full:
pass
if stop_event.is_set():
# Signal the end of output; this ensures that even if the queue takes a while to
# buffer, the main thread receives everything (and tosses this fake batch)
logging.info(
"Worker thread: Stop event detected, outputting is_final=True batch"
)
batch_queue.put(
Batch(
x=np.zeros((1, 1)),
y=np.zeros((1, 1)),
is_final=True,
mask=None,
patch_lengths=None,
ngram_ids=None,
)
)
break
try:
logging.info("Worker thread: outputting state")
state_queue.put(iterator.get_state(), timeout=1)
logging.info("Worker thread: state dump complete")
state_dumped_event.set()
logging.info("Worker thread: set state_dump_event")
except Full:
raise ValueError(
"Attempted to dump state into the state queue, but it was full"
)
class MultiprocessIterator(StatefulIterator):
"""
Design sketch of the multiprocess iterator:
Given the base_iterator, the only thing we do with this is call get_state()
so that we can pass that through to the background worker process.
The background process will receive this, rebuild the iterator, then start yielding from it.
However, in order to implement MultiprocessIterator.get_state(), we need to be able to accurately get
(1) the state of the iterator in the worker process
(2) the currently buffered items in the Queue
To do this, we use:
- batch_queue: This is the prefetch buffer the worker yields to and the main loop yields from
- state_queue: This size 1 queue will be how the worker sends the iterator state once it has halted iterating.
It must hold the state in addition to the last batch, if the queue was full at the time the stop event is sent.
- stop_iterating_event: Once this is issued from the main loop, the worker will stop iterating and enter cleanup.
During cleanup, the iterator will send the state of the current iterator to the main loop,
in addition to possibly the last batch if the batch_queue was full at the time
- state_dumped_event: When the main loop issues the stop_iterating_event, it will wait until the state_dumped_event to attempt
to get state from the state_queue. It must do this since the worker may take some time to create and send the state.
Once received by the main loop, the main loop can safely store the Queue (plus maybe the last batch) as the prefetch buffer,
get the worker iterator's state, and terminate the background process + delete associated objects.
At this point, calling create_iter() again will bootstrap everything from the stored state and the old iterator will throw an error
since it will not iterate anymore (so the caller must call create_iter() again to get a python iterator).
"""
def __init__(
self,
base_iterator: StatefulIterator,
*,
n_batches_to_prefetch: int,
prefetch_buffer: list | None = None
):
self.base_iterator = base_iterator
self.n_batches_to_prefetch = n_batches_to_prefetch
if prefetch_buffer is None:
prefetch_buffer = []
self.prefetch_buffer = prefetch_buffer
self.batch_queue = None
self.state_queue = None
self.producer = None
self.stop_iterating_event = None
self.state_dumped_event = None
def get_state(self) -> MultiprocessIteratorState:
"""
This is slightly unusual in that it effectively destroys the current iterator: it's
necessary to halt the background process and let it write its state back to the
main loop so that no data is lost.
"""
if self.producer is None:
serialized_prefetch_buffer = json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
)
return MultiprocessIteratorState(
base_iterator_state=self.base_iterator.get_state(),
n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=serialized_prefetch_buffer,
)
else:
logging.info("Main thread: Sending stop iteration event")
self.stop_iterating_event.set()
logging.info("Main thread: Waiting for state_dumped event")
self.state_dumped_event.wait()
self.prefetch_buffer = []
final_batch_received = False
while True:
try:
batch = self.batch_queue.get(timeout=1)
if batch.is_final:
final_batch_received = True
break
self.prefetch_buffer.append(batch)
except Empty:
logging.warning("Main thread: batch_queue is abnormally empty")
assert final_batch_received
try:
base_iterator_state = self.state_queue.get(timeout=1)
assert isinstance(base_iterator_state, IteratorState)
except Empty:
raise ValueError(
"Attempted to get the state, but it was unexpectantly missing"
)
self.base_iterator = base_iterator_state.build()
self.producer.close()
self.producer = None
self.batch_queue = None
self.state_queue = None
self.stop_iterating_event = None
self.state_dumped_event = None
return MultiprocessIteratorState(
base_iterator_state=self.base_iterator.get_state(),
n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
),
)
def create_iter(self):
logging.info("Main thread: Creating MP iterator")
# First yield from the stored prefetch buffer.
if self.prefetch_buffer is not None:
while len(self.prefetch_buffer) > 0:
item = self.prefetch_buffer.pop(0)
yield item
self.prefetch_buffer = None
assert (
self.producer is None
), "Cannot create two parallel iterators at once, call get_state() then remake to have two."
# using mp context manager avoids excessive CPU loading
ctx = mp.get_context("forkserver")
self.batch_queue = ctx.Manager().Queue(maxsize=self.n_batches_to_prefetch)
# We should only ever have one state, which is output when a stop event is detected
self.state_queue = ctx.Manager().Queue(maxsize=1)
self.stop_iterating_event = ctx.Event()
self.state_dumped_event = ctx.Event()
self.producer = mp.Process(
name="blt_data_loader",
target=start_work_from_state,
args=(
self.batch_queue,
self.state_queue,
self.stop_iterating_event,
self.state_dumped_event,
self.base_iterator.get_state(),
),
)
logger.info("Async dataloader started")
self.producer.start()
while True:
if self.producer.exitcode is not None:
raise RuntimeError(
"Data loader quit unexpectedly, real error has been raised previously"
)
try:
batch = self.batch_queue.get(timeout=0.1)
assert isinstance(batch, Batch)
assert (
not batch.is_final
), "is_final should only be used during get_state() being called"
yield batch
except Empty:
pass
if self.producer is None:
raise ValueError(
"Attempted to call this iterator after calling get_state(). You must call create_iter() to make a new iterator instead."
)
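# Illustrative usage sketch (not part of the original module; `packing_iterator`
# is a hypothetical PackingIterator instance):
#
#   mp_it = MultiprocessIterator(packing_iterator, n_batches_to_prefetch=4)
#   it = mp_it.create_iter()           # spawns the background worker
#   first_batch = next(it)
#   state = mp_it.get_state()          # halts the worker and drains its queue
#   resumed_it = state.build().create_iter()
#   second_batch = next(resumed_it)    # continues where the old iterator stopped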

View file

@@ -0,0 +1,226 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict
from bytelatent.data.data_types import Batch, BltSequence
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
class PackingArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
batch_size: int
seq_len: int
pad_id: int
max_length: int | None
pad_to_max_length: bool
enable_byte_ngrams: bool
class PackingIteratorState(BaseModel, IteratorState):
model_config = ConfigDict(extra="forbid")
sequence_iterator_state: SamplingIteratorState
packing_args: PackingArgs
def build(self) -> "PackingIterator":
return PackingIterator(
sequence_iterator=self.sequence_iterator_state.build(),
packing_args=self.packing_args,
)
def _merge_patch_seq_masks(bs, slen: int, mask_seqs: list[list[bool]]):
assert len(mask_seqs) == bs
lens = [len(m) for m in mask_seqs]
if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens):
return None
assert slen == max(lens) - 1
mask = np.zeros((bs, slen), dtype=bool)
for i, m in enumerate(mask_seqs):
if m is None:
print(
"Did not implement None mask, the mask should be True for all toks, so we need to pass that to this function."
)
raise NotImplementedError
mask[i][: len(mask_seqs[i]) - 1] = mask_seqs[i][1:]
return mask
def truncate_batch(
batch: Batch,
max_length: int,
pad_id: int,
pad_to_max_length: bool = False,
*,
enable_byte_ngrams: bool,
):
"""
Truncate x to a given size, making sure we remove the corresponding patch sizes in patch_lengths
and fix batch.mask.
batch.patch_lengths has unchanged shape
x,y, and mask may reduce in size
"""
if batch.patch_lengths is None:
return batch
seq_lengths = batch.patch_lengths.sum(axis=1)
max_length_adj = max_length + 1
if np.any(seq_lengths > max_length_adj):
for i in range(batch.x.shape[0]):
if seq_lengths[i] > max_length_adj:
# Find id of patch that tips over max_length + 1
count, j = 0, 0
while count + batch.patch_lengths[i, j] <= max_length_adj:
count += batch.patch_lengths[i, j]
j += 1
# Edit the batch
assert j < batch.patch_lengths.shape[1]
batch.x[i, max_length:] = pad_id
batch.y[i, max_length:] = pad_id
if batch.mask is not None:
batch.mask[i, max_length:] = False
batch.patch_lengths[i, j:] = 0
batch.patch_lengths[i, j] = max_length_adj - count
# Truncate if necessary.
if max_length < batch.x.shape[1]:
batch.x = batch.x[:, :max_length]
batch.y = batch.y[:, :max_length]
if batch.mask is not None:
batch.mask = batch.mask[:, :max_length]
# Right pad to max_length if necessary
elif pad_to_max_length:
if batch.x.shape[1] < max_length:
# NOTE: this has to be done on an actual patch.
non_zero_indices = (batch.patch_lengths != 0).sum(axis=1) - 1
non_zero_indices = np.maximum(0, non_zero_indices)
batch.patch_lengths[range(len(batch.patch_lengths)), non_zero_indices] += (
max_length - batch.x.shape[1]
)
# TODO: We could get rid of many of these complications by moving this function directly into the dataloader.
x = np.full((batch.x.shape[0], max_length), pad_id, dtype=batch.x.dtype)
x[:, : batch.x.shape[1]] = batch.x
batch.x = x
if batch.y.shape[1] < max_length:
y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype)
y[:, : batch.y.shape[1]] = batch.y
batch.y = y
if batch.mask is not None and batch.mask.shape[1] < max_length:
mask = np.full(
(batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
)
mask[:, : batch.mask.shape[1]] = batch.mask
batch.mask = mask
assert batch.x.shape[1] <= max_length
assert batch.y.shape[1] <= max_length
assert batch.mask is None or batch.mask.shape[1] <= max_length
assert np.all(max_length_adj - batch.patch_lengths.sum(axis=1) == 0)
if pad_to_max_length:
assert batch.x.shape[1] == max_length
assert batch.y.shape[1] == max_length
assert batch.mask is None or batch.mask.shape[1] == max_length
if enable_byte_ngrams:
raise NotImplementedError()
# (num_ngram, batch_size, seq_len)
ngram_ids = np.array(tokenizer.encode_token_ngrams(batch.x))
assert ngram_ids.shape[2] == batch.x.shape[1]
else:
ngram_ids = None
batch.ngram_ids = ngram_ids
class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
def __init__(
self,
sequence_iterator: StatefulIterator[BltSequence, Any],
*,
packing_args: PackingArgs,
):
self.sequence_iterator = sequence_iterator
self.packing_args = packing_args
def get_state(self):
return PackingIteratorState(
sequence_iterator_state=self.sequence_iterator.get_state(),
packing_args=self.packing_args,
)
def create_iter(self):
sequence_iter = self.sequence_iterator.create_iter()
batch_size = self.packing_args.batch_size
pad_id = self.packing_args.pad_id
seq_len = self.packing_args.seq_len
pad_to_max_length = self.packing_args.pad_to_max_length
enable_byte_ngrams = self.packing_args.enable_byte_ngrams
max_length = self.packing_args.max_length
while True:
tokens: list[list[int]] = []
masks: list[list[bool]] = []
patch_lengths: list[list[int]] = []
for _ in range(self.packing_args.batch_size):
sequence = next(sequence_iter)
_tokens = sequence.tokens
_mask = sequence.mask
_patch_lengths = sequence.patch_lengths
assert len(sequence.patch_lengths) == self.packing_args.seq_len
last_patch_length = 0
if _patch_lengths[0] > 1:
last_patch_length = _patch_lengths[-1]
_patch_lengths[0] -= 1
_patch_lengths = [1] + _patch_lengths[:-1]
tokens.append(_tokens[: len(_tokens) - last_patch_length])
masks.append(_mask[: len(_mask) - last_patch_length])
patch_lengths.append(_patch_lengths)
x_patch_lengths = np.array(patch_lengths)
# pad batch to same length
tok_seq_len = max([len(toks) for toks in tokens]) - 1
x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
y = np.full((batch_size, tok_seq_len), fill_value=pad_id)
for i, tok_seq in enumerate(tokens):
x[i, : len(tok_seq) - 1] = tok_seq[:-1]
y[i, : len(tok_seq) - 1] = tok_seq[1:]
# Adjust patch lengths to match x
x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
assert x_patch_lengths.shape == (batch_size, seq_len)
if enable_byte_ngrams:
raise NotImplementedError()
else:
ngram_ids = None
batch = Batch(
x=x,
y=y,
patch_lengths=x_patch_lengths,
ngram_ids=ngram_ids,
mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
)
assert (
x_patch_lengths.sum() == x.size + batch_size
), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
assert (
batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
assert np.all(
x_patch_lengths[:, 0] == 1
), f"first patch should always be 1, {x_patch_lengths[:, 0]}"
# cuda_gb_allocated = (torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024)
# cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024
# print(f"dataloader cuda_gb_allocated: {cuda_gb_allocated}, cuda_gb_reserved: {cuda_gb_reserved}")
truncate_batch(
batch,
max_length=max_length,
pad_id=pad_id,
pad_to_max_length=pad_to_max_length,
enable_byte_ngrams=enable_byte_ngrams,
)
yield batch
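For illustration (not part of this commit), a small numpy sketch of the packing step above: padding is absorbed into each row's last patch, so x_patch_lengths.sum() must equal x.size + batch_size.

import numpy as np

# Two hypothetical sequences packed with pad_id = 0 and two patches per row.
pad_id = 0
tokens = [[5, 6, 7, 8], [9, 10, 11]]
patch_lengths = np.array([[1, 3], [1, 2]])  # per-row sums equal the token counts
tok_seq_len = max(len(t) for t in tokens) - 1
x = np.full((len(tokens), tok_seq_len), pad_id)
y = np.full((len(tokens), tok_seq_len), pad_id)
for i, seq in enumerate(tokens):
    x[i, : len(seq) - 1] = seq[:-1]
    y[i, : len(seq) - 1] = seq[1:]
    patch_lengths[i, -1] += tok_seq_len - (len(seq) - 1)  # absorb right padding
# Same invariant asserted by PackingIterator.create_iter above.
assert patch_lengths.sum() == x.size + len(tokens)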

View file

@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any, Generator
import torch
from pydantic import BaseModel, ConfigDict
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
ArrowFileIteratorState,
)
from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
class PreprocessIteratorState(BaseModel, IteratorState):
model_config = ConfigDict(extra="forbid")
arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
add_tokens: bool
add_patches: bool
tokenizer_args: TokenizerArgs
patcher_args: PatcherArgs
def build(self):
arrow_iterator = self.arrow_file_iterator_state.build()
return PreprocessIterator(
arrow_iterator,
patcher_args=self.patcher_args,
tokenizer_args=self.tokenizer_args,
add_tokens=self.add_tokens,
add_patches=self.add_patches,
)
class PreprocessIterator(StatefulIterator):
"""
Take BltExamples with fields filled in only from ArrowFileIterator, and fill in fields that require
preprocessing like tokenization and patching
"""
def __init__(
self,
arrow_iterator: ArrowFileIterator,
*,
patcher_args: PatcherArgs,
tokenizer_args: TokenizerArgs,
add_tokens: bool = True,
add_patches: bool = True,
):
self.arrow_iterator = arrow_iterator
self.tokenizer_args = tokenizer_args
self.patcher_args = patcher_args
self.add_tokens = add_tokens
self.add_patches = add_patches
self.tokenizer: BltTokenizer | None = None
self.patcher: Patcher | None = None
def get_state(self) -> PreprocessIteratorState:
"""
The only state to maintain here is from arrow, there
isn't any internal state on this iterator.
"""
return PreprocessIteratorState(
arrow_file_iterator_state=self.arrow_iterator.get_state(),
tokenizer_args=self.tokenizer_args,
patcher_args=self.patcher_args,
add_tokens=self.add_tokens,
add_patches=self.add_patches,
)
def create_iter(self) -> Generator[BltExample, Any, None]:
if self.tokenizer is None and self.add_tokens:
self.tokenizer = self.tokenizer_args.build()
if self.patcher is None and self.add_patches:
self.patcher = self.patcher_args.build()
example_iter = self.arrow_iterator.create_iter()
for example in example_iter:
if self.add_tokens:
tokens = self.tokenizer.encode(example.text)
else:
tokens = example.tokens
if (
self.patcher is not None
and self.patcher.patching_mode == PatchingModeEnum.entropy
):
assert (
example.entropies is not None
), "For patching, entropies cannot be None"
entropies = torch.tensor(example.entropies).unsqueeze(0)
else:
entropies = None
if self.patcher is None:
patch_lengths = None
else:
patch_lengths = self.patcher.patch(
torch.tensor(tokens).unsqueeze(0),
include_next_token=False,
entropies=entropies,
)[0][0].tolist()
yield BltExample(
sample_id=example.sample_id,
text=example.text,
tokens=tokens,
mask=[True] * len(tokens),
patch_lengths=patch_lengths,
entropies=example.entropies,
)
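For illustration (not part of this commit), a minimal wiring sketch mirroring the tests later in this commit; the example source and tokenizer path are placeholders.

from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


def iter_tokens_and_patches(example_source, bpe_tokenizer_path: str):
    # example_source: any StatefulIterator yielding BltExample with text filled in.
    example_it = PreprocessIterator(
        example_source,
        tokenizer_args=TokenizerArgs(
            name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
        ),
        patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.space),
    )
    for example in example_it.create_iter():
        # Invariants the tests below rely on: one mask entry per token and
        # patch lengths that sum to the token count.
        assert len(example.tokens) == len(example.mask) == sum(example.patch_lengths)
        yield example.tokens, example.patch_lengths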

View file

@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState
class SamplingIteratorState(BaseModel):
model_config = ConfigDict(extra="forbid")
rng_state: dict[str, Any]
source_to_weight: dict[str, float]
source_to_iterator_state: dict[str, SequenceIteratorState]
def build(self) -> "SamplingIterator":
return SamplingIterator(
rng_state=self.rng_state,
source_to_weight=self.source_to_weight,
source_to_iterator={
source: state.build()
for source, state in self.source_to_iterator_state.items()
},
)
class SamplingIterator(StatefulIterator):
def __init__(
self,
*,
rng_state: dict[str, Any],
source_to_weight: dict[str, float],
source_to_iterator: dict[str, StatefulIterator],
):
self.rng = np.random.default_rng()
self.rng.bit_generator.state = rng_state
self.source_to_weight = source_to_weight
self.source_to_iterator = source_to_iterator
def get_state(self) -> SamplingIteratorState:
return SamplingIteratorState(
rng_state=self.rng.bit_generator.state,
source_to_weight=self.source_to_weight,
source_to_iterator_state={
source: iterator.get_state()
for source, iterator in self.source_to_iterator.items()
},
)
def create_iter(self):
n_sources = len(self.source_to_weight)
possible_sources = []
weights = []
for source, w in self.source_to_weight.items():
possible_sources.append(source)
weights.append(w)
source_to_python_iter = {
source: self.source_to_iterator[source].create_iter()
for source in possible_sources
}
while True:
norm_weights = np.array(weights) / np.array(weights).sum()
source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)]
yield next(source_to_python_iter[source_choice])
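For illustration (not part of this commit), a standalone numpy sketch of the weighted mixing loop in create_iter; the source names and weights are made up.

import numpy as np


def mix(source_to_weight, source_to_python_iter, seed=0):
    rng = np.random.default_rng(seed)
    sources = list(source_to_weight)
    weights = np.array([source_to_weight[s] for s in sources], dtype=float)
    norm_weights = weights / weights.sum()
    while True:
        # Draw a source index by weight, then take that source's next item.
        source = sources[rng.choice(len(sources), p=norm_weights)]
        yield next(source_to_python_iter[source])


def count_up(prefix):
    i = 0
    while True:
        yield f"{prefix}-{i}"
        i += 1


mixed = mix(
    {"wiki": 0.8, "code": 0.2},
    {"wiki": count_up("wiki"), "code": count_up("code")},
)
print([next(mixed) for _ in range(5)])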

View file

@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from logging import getLogger
from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict
from bytelatent.data.data_types import BltSequence
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.preprocess_iterator import (
PreprocessIterator,
PreprocessIteratorState,
)
logger = getLogger()
class SequencePackingArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
output_seq_len: int
buffer_size: int
class SequenceIteratorState(BaseModel, IteratorState):
model_config = ConfigDict(extra="forbid")
sequence_packing_args: SequencePackingArgs
preprocess_iterator_state: PreprocessIteratorState
rng_state: dict[str, Any]
def build(self):
preprocess_iterator = self.preprocess_iterator_state.build()
return SequenceIterator(
preprocess_iterator,
sequence_packing_args=self.sequence_packing_args,
rng_state=self.rng_state,
)
class SequenceIterator(StatefulIterator):
def __init__(
self,
preprocess_iterator: PreprocessIterator,
*,
rng_state: dict[str, Any],
sequence_packing_args: SequencePackingArgs,
):
self.preprocess_iterator = preprocess_iterator
self.sequence_packing_args = sequence_packing_args
self.output_seq_len = sequence_packing_args.output_seq_len
self.buffer_size = sequence_packing_args.buffer_size
self.rng = np.random.default_rng()
self.rng.bit_generator.state = rng_state
def get_state(self):
# TODO: need to also persist the current shuffle buffer
return SequenceIteratorState(
sequence_packing_args=self.sequence_packing_args,
preprocess_iterator_state=self.preprocess_iterator.get_state(),
rng_state=self.rng.bit_generator.state,
)
def create_iter(self):
example_iter = self.preprocess_iterator.create_iter()
n_buffer_patches = self.buffer_size * self.output_seq_len
patch_lengths: list[int] = []
tokens: list[int] = []
mask: list[bool] = []
first = True
for example in example_iter:
assert example.tokens is not None
assert example.mask is not None
assert example.patch_lengths is not None
assert len(example.tokens) != 0
assert len(example.mask) != 0
assert len(example.tokens) == len(example.mask)
assert len(example.tokens) == sum(example.patch_lengths)
tokens.extend(example.tokens)
mask.extend(example.mask)
patch_lengths.extend(example.patch_lengths)
while len(patch_lengths) >= n_buffer_patches:
if first:
first = False
logger.info("First buffer complete")
x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
self.buffer_size, self.output_seq_len
)
seq_tokens = []
seq_mask = []
start_id = 0
# We fix the number of patches and therefore global steps per batch
# so we have a variable number of tokens we need to account for
for num_tokens in x_patches.sum(axis=-1):
seq_tokens.append(tokens[start_id : start_id + num_tokens])
seq_mask.append(mask[start_id : start_id + num_tokens])
start_id += num_tokens
assert start_id == x_patches.sum()
# Remove what we just added from the buffer
patch_lengths = patch_lengths[n_buffer_patches:]
tokens = tokens[x_patches.sum() :]
mask = mask[x_patches.sum() :]
seq_patch_lengths: list[list[int]] = x_patches.tolist()
assert len(seq_patch_lengths) == self.buffer_size
for idx in self.rng.permutation(len(seq_patch_lengths)):
assert len(seq_patch_lengths[idx]) == self.output_seq_len
assert (
sum(seq_patch_lengths[idx])
== len(seq_tokens[idx])
== len(seq_mask[idx])
), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
yield BltSequence(
tokens=seq_tokens[idx],
mask=seq_mask[idx],
patch_lengths=seq_patch_lengths[idx],
)
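For illustration (not part of this commit), a toy walk-through of the buffer flush above: with buffer_size sequences of output_seq_len patches each, a flush consumes exactly as many tokens as its patch lengths sum to, and the remainder stays buffered.

import numpy as np

# Toy sizes (assumptions): 2 sequences per buffer, 3 patches per output sequence.
buffer_size, output_seq_len = 2, 3
n_buffer_patches = buffer_size * output_seq_len

patch_lengths = [1, 2, 3, 1, 1, 4, 2]  # seven buffered patch lengths
tokens = list(range(sum(patch_lengths)))  # one fake token per position

x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(buffer_size, output_seq_len)
start = 0
for num_tokens in x_patches.sum(axis=-1):
    seq = tokens[start : start + num_tokens]
    assert len(seq) == num_tokens
    start += num_tokens
assert start == x_patches.sum()
# One patch length (2) is left over and waits for the next flush.
assert patch_lengths[n_buffer_patches:] == [2]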

View file

@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import numpy as np
import pyarrow as pa
# pyarrow needs the initialization from this import
import pyarrow.dataset # pyright: ignore
from bytelatent.constants import BLT_DATA
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
ENTROPY_MODEL = "transformer_100m"
ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow")
def test_basic_arrow_file():
dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow")
n_head = 1000
head_df = dataset.head(n_head).to_pandas()
initial_state = ArrowFileIteratorState(
file_path=None,
num_workers=1,
worker_id=0,
preprocess_dir=None,
entropy_model_name=ENTROPY_MODEL,
dataset_files=[ARROW_TEST_DATA_1],
row_num=0,
arrow_batch_size=100,
)
arrow_file = initial_state.build()
start_state = arrow_file.get_state()
assert start_state.row_num == initial_state.row_num
sample_id = None
for example in arrow_file.create_iter():
sample_id = example.sample_id
assert head_df.iloc[0]["sample_id"] == sample_id
break
assert arrow_file.get_state().row_num == 1
arrow_file = initial_state.build()
for example in arrow_file.create_iter():
assert example.sample_id == sample_id
assert head_df.iloc[0]["sample_id"] == sample_id
break
# Test resume far enough in to be past the batch size of 100
resumed_state = ArrowFileIteratorState(
file_path=None,
num_workers=1,
worker_id=0,
preprocess_dir=None,
entropy_model_name=ENTROPY_MODEL,
dataset_files=[ARROW_TEST_DATA_1],
row_num=251,
arrow_batch_size=100,
)
arrow_file = resumed_state.build()
for example in arrow_file.create_iter():
assert example.sample_id == head_df.iloc[251]["sample_id"]
assert arrow_file.get_state().row_num == 252
break
world_rank = 1
world_size = 4
# Test World Size and Rank
rank_state = ArrowFileIteratorState(
file_path=None,
num_workers=world_size,
worker_id=world_rank,
preprocess_dir=None,
entropy_model_name=ENTROPY_MODEL,
dataset_files=[ARROW_TEST_DATA_1],
row_num=0,
arrow_batch_size=100,
)
arrow_file = rank_state.build()
expected_ids = []
for i in range(n_head):
if i % world_size == world_rank:
expected_ids.append(head_df.iloc[i]["sample_id"])
print(len(expected_ids))
i = 0
for example in arrow_file.create_iter():
assert example.sample_id == expected_ids[i]
i += 1
if i >= len(expected_ids):
break
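For illustration (not part of this commit), the round-robin sharding that the rank test above assumes (rows assigned by row index modulo num_workers):

def shard_rows(n_rows: int, num_workers: int, worker_id: int) -> list[int]:
    # Row i belongs to the worker whose id matches i % num_workers.
    return [i for i in range(n_rows) if i % num_workers == worker_id]


assert shard_rows(10, 4, 1) == [1, 5, 9]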

View file

@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import pandas as pd
from pydantic import BaseModel
from bytelatent.constants import BLT_DATA
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
class BltTestIteratorState(BaseModel, IteratorState):
position: int
total: int
def build(self):
blt_iter = BltTestIterator(total=self.total)
blt_iter.position = self.position
return blt_iter
class BltTestIterator(StatefulIterator):
def __init__(self, total: int):
self.position = 0
self.total = total
def get_state(self):
return BltTestIteratorState(position=self.position, total=self.total)
def create_iter(self):
for i in range(self.total):
self.position += 1
yield BltExample(
sample_id=f"test_{i}",
text=f"This is some test {i} text.",
tokens=None,
mask=None,
entropies=None,
patch_lengths=None,
)
class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
position: int
total: int
def build(self):
blt_iter = BltTestWithEntropiesIterator(total=self.total)
blt_iter.position = self.position
return blt_iter
class BltTestWithEntropiesIterator(StatefulIterator):
def __init__(self, total: int):
self.position = 0
self.total = total
def get_state(self):
return BltTestWithEntropiesIteratorState(position=self.position, total=self.total)
def create_iter(self):
text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
df = pd.read_json("fixtures/tokens_with_entropies.json")
tokens = df["token_ids"].tolist()
entropies = df["entropies"].tolist()
# BOS and EOS
assert len(tokens) == len(text) + 2
for i in range(self.total):
self.position += 1
yield BltExample(
sample_id=f"test_{i}",
text=text,
tokens=tokens,
mask=[True] * len(tokens),
entropies=entropies,
patch_lengths=None,
)
def test_preprocess_iter():
total = 3
tokenizer_args = TokenizerArgs(
name="blt",
init_kwargs={
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
},
)
for mode in [
PatchingModeEnum.bpe,
PatchingModeEnum.space,
]:
data_it = BltTestIterator(total)
patcher_args = PatcherArgs(patching_mode=mode)
example_it = PreprocessIterator(
data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
)
count = 0
for example in example_it.create_iter():
assert isinstance(example.tokens, list)
assert isinstance(example.tokens[0], int)
# BOS and EOS
assert len(example.tokens) == len(example.text) + 2
assert example.mask is not None
assert len(example.tokens) == len(example.mask)
count += 1
assert count == total
def test_non_entropy_patch_iter():
total = 3
tokenizer_args = TokenizerArgs(
name="blt",
init_kwargs={
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
},
)
for mode in [
PatchingModeEnum.bpe,
PatchingModeEnum.space,
]:
patcher_args = PatcherArgs(patching_mode=mode)
data_it = BltTestIterator(total)
example_it = PreprocessIterator(
data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
)
count = 0
for example in example_it.create_iter():
assert isinstance(example.patch_lengths, list)
assert isinstance(example.patch_lengths[0], int)
assert len(example.tokens) == sum(example.patch_lengths)
count += 1
assert count == total
def test_entropy_patch_iter():
total = 2
patcher_args = PatcherArgs(
patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627
)
tokenizer_args = TokenizerArgs(
name="blt",
init_kwargs={
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
},
)
data_it = BltTestWithEntropiesIterator(total)
example_it = PreprocessIterator(
data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
)
count = 0
for example in example_it.create_iter():
assert isinstance(example.patch_lengths, list)
assert isinstance(example.patch_lengths[0], int)
assert len(example.tokens) == sum(example.patch_lengths)
count += 1
assert count == total

View file

@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import pickle
from pathlib import Path
import numpy as np
from bytelatent import ByteLatentError
LOOKUP_OFFSET = 4
def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1):
"""
Wrapper function for applying the lookup table to each n-gram.
:param ngram: Array of numbers representing an n-gram.
:param lookup_table: Dictionary where keys are tuples (n-grams) and values are the desired outputs.
:param lookup_offset: Offset to add to the lookup result.
:return: The value associated with the n-gram tuple in the dictionary, or None if not found.
"""
def apply_lookup_table(ngram):
"""
Function to apply to each n-gram: converts it to a tuple and looks it up in a dictionary.
:param ngram: Array of numbers representing an n-gram.
:return: The value associated with the n-gram tuple plus lookup_offset, or 0 if not found.
"""
# Convert the n-gram to a tuple
ngram_tuple = tuple(ngram)
if ngram_tuple not in ngram_to_idx:
return 0
else:
return ngram_to_idx[ngram_tuple] + lookup_offset
return apply_lookup_table
def get_byte_ngrams_ids(
byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0
):
"""
Generate n-grams from a 2D numpy array.
:param n: The length of each n-gram.
:param pad_value: The value used for padding of the byte values to maintain the same dimensions for the n-grams.
:return: A 2D numpy array where each element is the ID of an n-gram offset by LOOKUP_OFFSET.
"""
num_rows, num_cols = byte_array.shape
# Create an array to hold the padded version of the original array
padded_array = np.pad(
byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value
)
# Use stride tricks to avoid explicit looping
strided = np.lib.stride_tricks.as_strided
shape = (num_rows, num_cols, n)
strides = padded_array.strides[:2] + (padded_array.strides[1],)
ngrams = strided(padded_array, shape=shape, strides=strides)
ngram_ids = np.apply_along_axis(
apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET), 2, ngrams
)
assert ngram_ids.shape == byte_array.shape
return ngram_ids
def reload_tables(
ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET
) -> tuple[dict[int, list], dict[tuple, int], dict[int, int]]:
"""
Reload lookup tables from a directory. Reload only the ngrams in the dictionary and per ngram,
only load up to the max specified size. Return the actual number of ngrams taken per ngram size.
"""
idx_to_ngram_tables = {}
ngram_to_idx_tables = {}
vocab_sizes = {}
for ngram, size in ngram_to_size.items():
with open(Path(ngram_table_dir) / f"ngram-{ngram}.pickle", "rb") as f:
# These are already sorted by count
# Value: tuple of: count, ngram, dataset
ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[
"counts"
]
table = [ngram for ngram, _ in ngram_data][:size]
if len(table) != size:
raise ValueError(
f"Ngram table for {ngram}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}"
)
ngram_to_idx = {ngram: idx for idx, ngram in enumerate(table)}
actual_size = len(table)
idx_to_ngram_tables[ngram] = table
ngram_to_idx_tables[ngram] = ngram_to_idx
vocab_sizes[ngram] = actual_size + offset
return ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes
def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None:
if ngram_to_size_str is None:
return None
ngram_to_size = {}
for entry in ngram_to_size_str.split(","):
ngram, size = entry.split(":")
ngram = int(ngram)
size = int(size)
ngram_to_size[ngram] = size
return ngram_to_size
class NgramProcessor:
def __init__(
self,
ngram_table_dir: str | None = None,
ngram_to_size: dict[int, int] | None = None,
):
if ngram_table_dir is None or ngram_to_size is None:
raise ByteLatentError(
"ngram_table_dir and ngram_to_size cannot be none if enable_byte_ngrams is True"
)
(
self.ngram_to_idx_tables,
self.idx_to_ngram_tables,
self.ngram_vocab_sizes,
) = reload_tables(ngram_table_dir, ngram_to_size)
# Lowest to highest ngram
self.ngram_sizes = sorted(list(self.ngram_to_idx_tables.keys()))
# Although the model might not use all the ngrams, we need the tokenizer
# to produce ngram_ids such that index zero is the 2-gram, later on in
# src.model.megabyte.Megabyte.forward
assert self.ngram_sizes[0] == 2
def encode_single_ngram_table(self, data: np.ndarray, n: int):
"""
Return the n-grams of the input data for a given n
numpy array with ids of shape data.shape
"""
return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0)
def encode_token_ngrams(self, data: np.ndarray):
"""
Return the n-grams of the input data.
output shape: [ids with data.shape for n in self.ngram_sizes]
"""
return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes]
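For illustration (not part of this commit), a standalone numpy check of the sliding-window lookup performed by get_byte_ngrams_ids, using a toy two-gram table with the same left padding and LOOKUP_OFFSET = 4 convention.

import numpy as np

byte_array = np.array([[10, 11, 12, 13]])
n = 2
ngram_to_idx = {(10, 11): 0, (11, 12): 1}  # toy table; real tables come from reload_tables
padded = np.pad(byte_array, ((0, 0), (n - 1, 0)), constant_values=0)
windows = np.lib.stride_tricks.sliding_window_view(padded, n, axis=1)
ids = np.array(
    [
        [ngram_to_idx.get(tuple(int(v) for v in w), -4) + 4 for w in row]
        for row in windows
    ]
)
# Unknown n-grams map to 0; known ones are offset by 4, and the shape is preserved.
assert ids.shape == byte_array.shape
assert ids.tolist() == [[0, 4, 5, 0]]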

609
bytelatent/data/patcher.py Normal file
View file

@@ -0,0 +1,609 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import math
import time
from collections import defaultdict
from enum import Enum
import torch
from pydantic import BaseModel
from torch.nn import functional as F
from bytelatent.distributed import get_local_rank
from bytelatent.entropy_model import load_entropy_model
# from src.slurm import get_local_rank
from bytelatent.tokenizers.constants import BPE_ID, OFFSET
class PatchingModeEnum(str, Enum):
entropy = "entropy"
bpe = "bpe"
bpe_patcher = "bpe_patcher"
space = "space"
class PatcherArgs(BaseModel):
patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
patching_device: str = "cuda"
entropy_model_checkpoint_dir: str | None = None
realtime_patching: bool = False
threshold: float = 1.335442066192627
threshold_add: float | None = None
max_patch_length: int | None = None
patch_size: float = 4.5
patching_batch_size: int = 1
data_loader_patching: bool = False
device: str = "cuda"
monotonicity: bool = False
log_time: bool = False
def build(self) -> "Patcher":
return Patcher(self)
def entropy(scores):
"""
scores: [bs, seq_len, vocab]
returns [bs, seq_len]
Computes the entropy for each token in the batch.
Note: uses natural log.
"""
log_probs = F.log_softmax(scores, dim=-1)
probs = torch.exp(log_probs)
p_log_p = log_probs * probs
entropy = -p_log_p.sum(dim=-1)
return entropy
def calculate_entropies(
tokens: torch.Tensor, entropy_model, patching_batch_size, device: str | None = None
):
"""
tokens: 2D tensor of shape [batch_size, seq_len]
Return 2D tensor of shape [batch_size, seq_len] with entropies for each token.
Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
"""
with torch.no_grad():
entropies = []
max_length = getattr(entropy_model, "max_length", 8192)
batch_numel = max_length * patching_batch_size
splits = torch.split(tokens.flatten(), batch_numel)
for split in splits:
pad_size = (max_length - (split.numel() % max_length)) % max_length
pad = torch.zeros(
pad_size, dtype=split.dtype, device=split.device, requires_grad=False
)
split = torch.cat((split, pad), dim=0)
split = split.reshape(-1, max_length)
if device is not None:
split = split.to(device)
assert torch.all(split >= 0) and torch.all(split < 260)
pred, _ = entropy_model(split)
pred = pred.reshape(-1, pred.shape[-1])[
: split.numel() - pad_size, :
] # [batch_size * seq_len, vocab]
pred_entropies = entropy(pred)
entropies.append(pred_entropies)
entropies = torch.cat(entropies, dim=0)
entropies = entropies.reshape(tokens.shape)
return entropies
def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
"""
entropies: [bs, seq_len] torch tensor of entropies
t: threshold
returns [bs, seq_len] mask where True indicates the start of a patch
"""
bs, seq_len = entropies.shape
mask = torch.zeros_like(entropies, dtype=torch.bool)
mask[:, 0] = True
# Calculate differences between consecutive elements along the sequence length
differences = entropies[:, 1:] - entropies[:, :-1]
# Calculate conditions for all elements except the first one in each sequence
condition = differences > t
# Update the mask based on the condition
mask[:, 1:] = condition
return mask
def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
"""
entropies: [bs, seq_len] torch tensor of entropies
t: threshold
returns [bs, seq_len] mask where True indicates the start of a patch
"""
bs, seq_len = entropies.shape
mask = torch.zeros_like(entropies, dtype=torch.bool)
mask[:, 0] = True
# Calculate differences between consecutive elements along the sequence length
differences = entropies[:, 1:] - entropies[:, :-1]
# Calculate conditions for all elements except the first one in each sequence
condition = (differences > t_add) & (entropies[:, 1:] > t) & (~mask[:, :-1])
# Update the mask based on the condition
mask[:, 1:] = condition
return mask
def patch_start_ids_from_patch_start_mask(patch_start_mask):
bs, trunc_seq_len = patch_start_mask.shape
max_patches = patch_start_mask.sum(dim=1).max()
if max_patches == 0:
patch_start_ids = torch.full(
(bs, trunc_seq_len),
trunc_seq_len,
dtype=torch.long,
device=patch_start_mask.device,
)
else:
patch_ids = (
torch.arange(trunc_seq_len, device=patch_start_mask.device)
.unsqueeze(0)
.repeat(bs, 1)
)
extra_patch_ids = torch.full(
(bs, trunc_seq_len),
trunc_seq_len,
dtype=torch.long,
device=patch_start_mask.device,
)
all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
patch_start_mask_padded = torch.cat(
(patch_start_mask, ~patch_start_mask), dim=1
)
patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(
bs, trunc_seq_len
)[:, :max_patches]
return patch_start_ids
def check_non_zero_after_zero(tensor):
zero_mask = tensor == 0
shifted_mask = torch.cat(
[
torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
zero_mask[:, :-1],
],
dim=1,
)
non_zero_after_zero = (tensor != 0) & shifted_mask
return non_zero_after_zero.any()
def patch_lengths_from_start_ids(patch_start_ids, seq_len):
"""
Calculate patch lengths from start ids.
start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then
the rest are filled to the seq len.
seq_len: ex: 7 length of the sequence
returns the patch lengths:
[1, 6, 0, 0, 0, 0, 0] for the above example (trailing zeros come from the filler start ids).
"""
last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
patch_lengths = patch_end_ids - patch_start_ids + 1
assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
return patch_lengths
def find_space_patch_start_ids(tokens):
bs, seq_len = tokens.shape
tokens_no_offset = tokens - OFFSET
patch_end_mask = (
(tokens_no_offset < ord("0"))
| ((ord("9") < tokens_no_offset) & (tokens_no_offset < ord("A")))
| ((ord("Z") < tokens_no_offset) & (tokens_no_offset < ord("a")))
| ((ord("z") < tokens_no_offset) & (tokens_no_offset < 0b1000_0000))
| (0b1100_0000 <= tokens_no_offset)
)
patch_end_mask[:, 1:] &= patch_end_mask[:, :-1].bitwise_not()
patch_end_mask |= tokens < OFFSET
patch_start_mask = torch.cat(
[
torch.tensor([1, 1], device=tokens.device, dtype=torch.bool)
.unsqueeze(0)
.repeat(bs, 1),
patch_end_mask[:, 1:],
],
dim=1,
)
max_patches = patch_start_mask.sum(dim=1).max()
patch_ids = (
torch.arange(seq_len + 1, device=tokens.device).unsqueeze(0).repeat(bs, 1)
)
extra_patch_ids = torch.full(
(bs, seq_len + 1), seq_len + 1, dtype=torch.long, device=tokens.device
)
all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, -1)[
:, :max_patches
]
return patch_start_ids
def to_device(entropy_model, device=None):
if device == "cuda":
rank = get_local_rank()
device = f"cuda:{rank}"
entropy_model = entropy_model.to(device)
return entropy_model, device
def model_pred_to_bpe_patching_pred(pred):
_, indices = torch.max(pred, dim=1)
return indices == BPE_ID
def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None):
assert tokens.device == torch.device(
"cpu"
), f"{tokens.device} != cpu expects tokens to be on cpu"
with torch.no_grad():
bpe_patcher_device, device = to_device(
bpe_patcher, device
) # Get entropy model to right rank device.
bpe_patching_mask = []
max_length = getattr(bpe_patcher, "max_length", 8192)
batch_numel = max_length * patching_batch_size
splits = torch.split(tokens.flatten(), batch_numel)
for split in splits:
pad_size = (max_length - (split.numel() % max_length)) % max_length
pad = torch.zeros(
pad_size, dtype=split.dtype, device=split.device, requires_grad=False
)
split = torch.cat((split, pad), dim=0)
split = split.reshape(-1, max_length).to(device)
assert torch.all(split >= 0) and torch.all(split < 260)
pred = bpe_patcher_device(split)
pred_cpu = pred[0].cpu()
pred_cpu = pred_cpu.reshape(-1, pred_cpu.shape[-1])[
: split.numel() - pad_size, :
] # [batch_size * seq_len, vocab]
bpe_patching_pred = model_pred_to_bpe_patching_pred(pred_cpu)
bpe_patching_mask.append(bpe_patching_pred)
bpe_patching_mask = torch.cat(bpe_patching_mask, dim=0)
bpe_patching_mask = bpe_patching_mask.reshape(tokens.shape)
return bpe_patching_mask
def find_bpe_patcher_patch_start_ids(
tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True
):
bs, seq_len = tokens.shape
first_ids = (
torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
.unsqueeze(0)
.repeat(bs, 1)
)
preds_truncation_len = first_ids.shape[1]
token_input = tokens[:, 1:] if include_next_token else tokens[:, 1:-1]
if token_input.shape[1] >= 1:
patch_start_mask = apply_bpe_patcher(
token_input, bpe_patcher, patching_batch_size, device
)
assert (
patch_start_mask.shape[1]
== tokens.shape[1] + include_next_token - preds_truncation_len
), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
patch_start_ids = torch.cat(
(first_ids, patch_start_ids + preds_truncation_len), dim=1
)
else:
patch_start_ids = first_ids
return patch_start_ids
def find_entropy_patch_start_ids(
entropies,
patch_size=None,
threshold=None,
threshold_add=None,
monotonicity=False,
include_next_token=True,
):
"""
Use entropies to find the start ids of each patch.
Use patch_size or threshold to figure out the total number of patches to allocate.
When threshold is not None the number of patches is not constant between
different sequences, but patches can be identified incrementally rather than
decided globally using the entire sequence.
"""
bs, seq_len = entropies.shape[:2]
first_ids = (
torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
.unsqueeze(0)
.repeat(bs, 1)
)
preds_truncation_len = first_ids.shape[
1
] # remove the first preds because they will be start of patches.
entropies = entropies[:, 1:]
if threshold is None:
num_patches = seq_len // patch_size
patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
patch_start_ids = patch_start_ids.sort(dim=1).values
else:
# Assumes that there is at least one token going over the threshold
if monotonicity:
patch_start_mask = patch_start_mask_from_entropy_with_monotonicity(
entropies, threshold
)
elif threshold_add is not None and threshold is not None:
patch_start_mask = patch_start_mask_global_and_monotonicity(
entropies, threshold, threshold_add
)
else:
patch_start_mask = entropies > threshold
if not include_next_token:
patch_start_mask = patch_start_mask[:, :-1]
# patch_start_mask[1:] |= tokens[:-1] < OFFSET
patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
patch_start_ids = torch.cat(
(first_ids, patch_start_ids + preds_truncation_len), dim=1
)
return patch_start_ids
def rightpad(seq, pad_id, max_len):
return seq + [pad_id] * (max_len - len(seq))
def find_bpe_delim_patch_start_ids(tokens, delim):
ids = (tokens[:, :-1] == delim).nonzero(as_tuple=False)
out = [[0, 1] for _ in range(tokens.shape[0])]
for x, y in ids:
# start is at delim + 1, delim should be the last element in the patch.
out[x.item()].append(y.item() + 1)
max_len = max([len(elt) for elt in out])
out = [rightpad(elt, tokens.shape[1], max_len) for elt in out]
patch_start_ids = torch.tensor(out, dtype=tokens.dtype, device=tokens.device)
return patch_start_ids
def find_lookup_table_start_mask(
tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
):
window_size = lookup_table.ndim
# Unfold the tensor to get sliding windows
unfolded = tokens.unfold(1, window_size, 1)
# Gather indices for each dimension
indices = [unfolded[..., i] for i in range(window_size)]
# Access the lookup table using the gathered indices
result = lookup_table[indices]
return result
def find_lookup_table_patch_start_ids(
tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
):
bs, seq_len = tokens.shape
first_ids = (
torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
.unsqueeze(0)
.repeat(bs, 1)
)
preds_truncation_len = first_ids.shape[1]
window_size = lookup_table.ndim
assert window_size == 2, f"{window_size} != 2"
# output length: token_input length - window_size + 1; prepending first_ids should give tokens length + 1 when include_next_token, otherwise tokens length
token_input = (
tokens if include_next_token else tokens[:, : -preds_truncation_len + 1]
)
if token_input.shape[1] >= window_size:
patch_start_mask = find_lookup_table_start_mask(
token_input, lookup_table, include_next_token
)
assert (
patch_start_mask.shape[1]
== tokens.shape[1] + include_next_token - preds_truncation_len
), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
patch_start_ids = torch.cat(
(first_ids, patch_start_ids + preds_truncation_len), dim=1
)
else:
patch_start_ids = first_ids
return patch_start_ids
def split_large_numbers(lst, m):
new_lst = []
for i in lst:
if i > m:
while i > m:
new_lst.append(m)
i -= m
new_lst.append(i)
else:
new_lst.append(i)
assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
return new_lst
class Patcher:
def __init__(self, patcher_args: PatcherArgs):
self.patcher_args = patcher_args
self.patching_mode = patcher_args.patching_mode
self.realtime_patching = patcher_args.realtime_patching
if self.realtime_patching:
assert (
patcher_args.entropy_model_checkpoint_dir is not None
), "Cannot require realtime patching without an entropy model checkpoint"
entropy_model = load_entropy_model(
patcher_args.entropy_model_checkpoint_dir
)
entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
self.entropy_model = entropy_model
else:
self.entropy_model = None
self.threshold = patcher_args.threshold
self.threshold_add = patcher_args.threshold_add
self.max_patch_length = patcher_args.max_patch_length
self.patch_size = patcher_args.patch_size
self.patching_batch_size = patcher_args.patching_batch_size
self.data_loader_patching = patcher_args.data_loader_patching
self.device = patcher_args.device
self.monotonicity = patcher_args.monotonicity
self.log_time = patcher_args.log_time
if self.log_time:
self.log = defaultdict(float)
def patch(
self,
tokens: torch.Tensor,
include_next_token: bool = False,
preds: torch.Tensor | None = None,
entropies: torch.Tensor | None = None,
threshold: float | None = None,
) -> torch.Tensor:
"""
tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.)
-> output tensor: [batch_size, max_num_patches]
each tensor is processed independently and gets right padded with zeros.
Patching with the following modes:
1. patching_mode = None: static patch size
2. patching_mode = "entropy":
calculate entropy of each token, allocate patches so that the total
number of patches is the same as static patching but choose to begin
patches on tokens where the model is most uncertain (highest entropy).
When threshold is provided, it uses the threshold to decide when to
start a new patch.
3. patching_mode = "space":
use space like tokens to define the patches.
4. patching_mode = "bpe":
use bpe delim tokens to define the patches.
To correctly patch the last token, it may be necessary to include the next token in the patch
lengths calculations. This is controlled by the include_next_token argument.
"""
bs, seq_len = tokens.shape
seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
scores = None
# STATIC
if self.patching_mode is None:
patch_lengths = torch.zeros(
(bs, math.ceil(seq_len_next_tok / self.patch_size)),
dtype=tokens.dtype,
device=tokens.device,
).fill_(self.patch_size)
if seq_len_next_tok % self.patch_size != 0:
patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
# ENTROPY
elif self.patching_mode == PatchingModeEnum.entropy:
if self.log_time:
s = time.time()
if entropies is not None:
scores = torch.tensor(entropies, dtype=torch.float32)
elif preds is not None:
scores = entropy(preds)
else:
start_entropies = time.time()
scores = calculate_entropies(
tokens,
self.entropy_model,
self.patching_batch_size,
self.device,
)
if self.log_time:
self.log["calculate_entropies"] += time.time() - s
s = time.time()
patch_start_ids = find_entropy_patch_start_ids(
scores,
self.patch_size,
include_next_token=include_next_token,
threshold=threshold if threshold is not None else self.threshold,
threshold_add=self.threshold_add,
monotonicity=self.monotonicity,
)
if self.log_time:
self.log["find_entropy_patch_start_ids"] += time.time() - s
s = time.time()
patch_lengths = patch_lengths_from_start_ids(
patch_start_ids, seq_len_next_tok
)
if self.log_time:
self.log["patch_lengths_from_start_ids"] += time.time() - s
s = time.time()
# BPE
elif self.patching_mode == PatchingModeEnum.bpe:
patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID)
patch_lengths = patch_lengths_from_start_ids(
patch_start_ids, seq_len_next_tok
)
elif self.patching_mode == PatchingModeEnum.bpe_patcher:
patch_start_ids = find_bpe_patcher_patch_start_ids(
tokens,
self.entropy_model,
self.patching_batch_size,
self.device,
include_next_token,
)
patch_lengths = patch_lengths_from_start_ids(
patch_start_ids, seq_len_next_tok
)
# SPACE
elif self.patching_mode == PatchingModeEnum.space:
patch_start_ids = find_space_patch_start_ids(tokens)
patch_lengths = patch_lengths_from_start_ids(
patch_start_ids, seq_len_next_tok
)
else:
raise NotImplementedError(f"self.patching_mode {self.patching_mode}")
# Apply any processing to patch lengths
if self.max_patch_length is not None:
# TODO: avoid going back to a list here.
patch_lengths = [
split_large_numbers(pl, self.max_patch_length)
for pl in patch_lengths.tolist()
]
max_len = max([len(pl) for pl in patch_lengths])
patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
patch_lengths = torch.tensor(
patch_lengths, dtype=tokens.dtype, device=tokens.device
)
assert not check_non_zero_after_zero(patch_lengths)
# Find the last non-zero column index using argmax on a reversed version of the tensor
last_non_zero_col_reversed = (
(patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
)
# Slice the tensor up to the last non-zero column
patch_lengths = patch_lengths[
:, : patch_lengths.shape[1] - last_non_zero_col_reversed
]
assert (
torch.sum(patch_lengths)
== tokens.numel() + include_next_token * tokens.shape[0]
), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}"
if self.log_time:
self.log["postprocessing_patch_lengths"] += time.time() - s
self.log["tokens"] += patch_lengths.sum().item()
return patch_lengths, scores
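For illustration (not part of this commit), a self-contained sketch of the threshold-based entropy path in Patcher.patch with include_next_token=False, using hand-picked entropies instead of a real entropy model; the helpers are assumed importable from bytelatent.data.patcher, where they are defined above.

import torch

from bytelatent.data.patcher import (
    patch_lengths_from_start_ids,
    patch_start_ids_from_patch_start_mask,
)

entropies = torch.tensor([[0.2, 1.9, 0.3, 0.1, 2.5, 0.4, 1.7, 0.2]])
threshold = 1.5
first_ids = torch.tensor([[0, 1]])  # positions 0 and 1 always start patches
# Drop the first entropy, apply the threshold, then drop the last column (no next token).
patch_start_mask = (entropies[:, 1:] > threshold)[:, :-1]
patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
patch_start_ids = torch.cat((first_ids, patch_start_ids + first_ids.shape[1]), dim=1)
patch_lengths = patch_lengths_from_start_ids(patch_start_ids, entropies.shape[1])
# Patch lengths cover every token exactly once: [[1, 1, 3, 2, 1]].
assert patch_lengths.sum().item() == entropies.shape[1]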

478
bytelatent/distributed.py Normal file
View file

@@ -0,0 +1,478 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import atexit
import contextlib
import logging
import multiprocessing as mp
import os
import random
import shutil
import signal
import socket
import subprocess
import sys
import tempfile
from dataclasses import asdict, dataclass
from functools import lru_cache, partial, reduce
from itertools import chain
from typing import List, Optional, Tuple, Union
import torch
# for no recompute ops
import xformers.ops
from pydantic import BaseModel, ConfigDict
from torch import distributed as dist
from torch.distributed import ReduceOp
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import (
CheckpointPolicy,
create_selective_checkpoint_contexts,
)
from bytelatent.float8 import convert_linears_to_fp8
logger = logging.getLogger()
# for selective AC
default_no_recompute_ops = {
torch.ops.aten.mm.default,
torch.ops.aten._scaled_mm.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.c10d_functional.reduce_scatter_tensor.default,
torch.ops.xformers_flash.flash_fwd.default,
torch.ops.xformers.efficient_attention_forward_cutlass.default,
}
class DistributedArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
dp_shard: int = (
1  # Into how many shards to split the model weights. Typically the number of GPUs in a node.
)
dp_replicate: int = (
1  # How many times to replicate the model weights. Typically the number of nodes.
)
tp_size: int = 1
selective_activation_checkpointing: bool = False
compile: bool = False
fsdp_type: str = "no_shard"
model_dtype: str = "bf16"
float8_recipe: str | None = None
float8_filter: str = r"layers\.[0-9]+\."
matmul_allow_tf32: bool = False
allow_bf16_reduced_precision_reduction: bool = True
detect_anomaly: bool = False
compile_cache_size_limit: int = 8
spawn_method: str = "forkserver"
class EnvironmentArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
# Use GNU openMP (GOMP) instead of Intel OpenMP [Intel Math Kernel Library (MKL)]
MKL_SERVICE_FORCE_INTEL: str = "GNU"
OMP_NUM_THREADS: str = "1"
MKL_NUM_THREADS: str = "1"
# faster intra-node collectives, seems to be a cluster specific flag
ENABLE_INTRA_NODE_COMM: str = "1"
# avoids OOMs with long context
TORCH_NCCL_AVOID_RECORD_STREAMS: str = "1"
# increase the NCCL InfiniBand timeout to avoid spurious NCCL errors; 22 corresponds to roughly a 16s timeout
NCCL_IB_TIMEOUT: str = "22"
NCCL_DEBUG: str = "INFO"
TORCH_NCCL_ASYNC_ERROR_HANDLING: str = "1"
def get_device_mesh(distributed_args: DistributedArgs):
tp_size = distributed_args.tp_size
dp_replicate = distributed_args.dp_replicate
dp_shard = distributed_args.dp_shard
assert (
dp_replicate * dp_shard * tp_size == get_world_size()
), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"
dims = []
names = []
if dp_replicate >= 1:
dims.append(dp_replicate)
names.append("dp_replicate")
if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
dims.append(dp_shard)
names.append("dp_shard")
if tp_size > 1:
dims.append(tp_size)
names.append("tp")
dims = tuple(dims)
names = tuple(names)
return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)
def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
tensor = torch.tensor(x).cuda()
dist.all_reduce(tensor, op=ReduceOp.MAX, group=mesh.get_group() if mesh else None)
return tensor
def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
tensor = torch.tensor(x).cuda()
dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
return tensor
def dist_mean_dict(x):
r = dict()
for k in x:
r[k] = dist_mean(x[k])
r[k] = r[k].item() if (r[k].dim() == 0) else r[k].tolist()
return r
@lru_cache()
def get_is_torch_run() -> bool:
return os.environ.get("LOCAL_RANK") is not None
@lru_cache()
def get_is_slurm_job() -> bool:
return "SLURM_JOB_ID" in os.environ and not get_is_torch_run()
@lru_cache()
def get_global_rank() -> int:
if get_is_torch_run():
return int(os.environ["RANK"])
elif get_is_slurm_job():
return int(os.environ["SLURM_PROCID"])
else:
return 0
@lru_cache()
def get_local_rank() -> int:
if get_is_torch_run():
return int(os.environ["LOCAL_RANK"])
elif get_is_slurm_job():
return int(os.environ["SLURM_LOCALID"])
else:
return 0
@lru_cache()
def get_world_size() -> int:
if get_is_torch_run():
return int(os.environ["WORLD_SIZE"])
elif get_is_slurm_job():
return int(os.environ["SLURM_NTASKS"])
else:
return 1
@lru_cache()
def get_is_master() -> bool:
return get_global_rank() == 0
@lru_cache()
def get_master_port(job_id: int) -> int:
if get_is_torch_run():
return int(os.environ["MASTER_PORT"])
else:
MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000)
rng = random.Random(job_id)
return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
@lru_cache()
def get_master_addr() -> str:
if get_is_torch_run():
return os.environ["MASTER_ADDR"]
elif get_is_slurm_job():
hostnames = subprocess.check_output(
["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]]
)
return hostnames.split()[0].decode("utf-8")
else:
return "127.0.0.1"
def setup_env(env_args: EnvironmentArgs):
env_vars = env_args.model_dump()
# When using Triton, it attempts to locate prebuilt kernels in a cache
# located at ~/.triton/cache, but when that's backed by NFS this can fail
# with a "OSError: [Errno 116] Stale file handle" error. If we were to set
# it to a local directory it would belong to the first user who created it
# and it would fail for the job of any other successive user assigned to
# that machine. To avoid all this mess we use a temporary per-process cache.
triton_cache_dir = tempfile.mkdtemp()
atexit.register(shutil.rmtree, triton_cache_dir, ignore_errors=True)
env_vars["TRITON_CACHE_DIR"] = triton_cache_dir
# We change the tmp dir to /scratch in case it's slurm job
# This avoids filling up the host's usually limited tmpfs
# A full tmpfs leads to very slow creation of processes and weird bugs
if get_is_slurm_job():
new_tmp = f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}"
if os.path.exists(new_tmp):
env_vars["TMP_DIR"] = new_tmp
for name, value in env_vars.items():
if os.environ.get(name) != str(value):
os.environ[name] = str(value)
logger.warning(f"WARNING: Setting {name} to {value}")
def setup_torch_distributed(dist_args):
"""
Handle single and multi-GPU / multi-node / SLURM jobs.
Initialize the following variables:
- global_rank
- world_size
"""
mp.set_start_method(dist_args.spawn_method)
with mp.Manager():
pass
local_rank = get_local_rank()
os.environ["RANK"] = str(get_global_rank())
os.environ["WORLD_SIZE"] = str(get_world_size())
os.environ["MASTER_ADDR"] = get_master_addr()
os.environ["MASTER_PORT"] = str(
get_master_port(job_id=int(os.environ.get("SLURM_JOB_ID", -1)))
)
if get_is_torch_run():
logger.info(f"Run launched with torchrun, local rank: {local_rank}")
elif get_is_slurm_job():
logger.info(f"Run launched with slurm, local rank: {local_rank}")
else:
logger.info("Single GPU job")
logger.info(f"ENV: {os.environ}")
# set GPU device
assert 0 <= local_rank < 8
if dist_args.matmul_allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
logger.warning(
f"WARNING: Setting torch.backends.matmul.allow_tf32 to True. This is faster but less accurate."
)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
dist_args.allow_bf16_reduced_precision_reduction
)
if torch.cuda.device_count() > 1:
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(init_method="env://", backend="nccl")
torch.autograd.set_detect_anomaly(dist_args.detect_anomaly)
def get_module(module, access_string):
names = access_string.split(sep=".")
return reduce(getattr, names, module)
def set_module(module, access_string, value):
names = access_string.split(sep=".")
parent = reduce(getattr, names[:-1], module)
setattr(parent, names[-1], value)
def default_fsdp_grouping_plan(n_layers: int) -> List[Tuple[str, bool]]:
return [(f"layers.{i}", i < n_layers - 1) for i in range(n_layers)]
def get_default_policy(no_recompute_ops=None):
no_recompute_ops = no_recompute_ops or default_no_recompute_ops
def default_policy(ctx, func, *args, **kwargs):
return (
CheckpointPolicy.MUST_SAVE
if func in no_recompute_ops
else CheckpointPolicy.PREFER_RECOMPUTE
)
return default_policy
@torch.no_grad()
def check_model_value_range(
model: torch.nn.Module, range: float = 1e3, std: float = 1e3
):
for name, param in chain(model.named_parameters(), model.named_buffers()):
if isinstance(param, DTensor):
param = param.to_local()
if param.numel() == 0:
logger.warning(
f"Model parameter {name} is empty, probably because of FSDP sharding"
)
continue
if torch.isnan(param).any() or torch.isinf(param).any():
logger.warning(f"Model parameter {name} contains NaN or Inf")
param_range = param.max() - param.min()
param_std = param.std()
if param_range > range:
logger.warning(
f"Model parameter {name} has a suspiciously large range ({param_range}): please check initialization and init_weights is defined and called"
)
if param_std > std:
logger.warning(
f"Model parameter {name} has a suspiciously large standard deviation ({param_std}): please check initialization and init_weights is defined and called"
)
if (param == 0).all():
logger.warning(
f"Model parameter {name} is all zeros: it might be because of a missing initialization"
)
def init_signal_handler(callable):
"""
Handle signals sent by SLURM for time limit / pre-emption.
"""
signal.signal(signal.SIGUSR2, callable)
logger.warning("Signal handler installed.")
def requeue_slurm_job():
prod_id = int(os.environ["SLURM_PROCID"])
logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
if prod_id == 0 and os.environ.get("LAUNCH_WITH", "") != "DORA":
logger.warning("Requeuing job " + os.environ["SLURM_JOB_ID"])
os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
else:
logger.warning("Not the master process, no need to requeue.")
sys.exit(0)
@contextlib.contextmanager
def clean_env():
distrib_names = (
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_WORLD_SIZE",
"TORCHELASTIC_RUN_ID",
"DORA_FORCE_DISTRIB",
)
cluster_env = {
x: os.environ.pop(x)
for x in os.environ
if x.startswith(
("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_", "WANDB_")
)
or x in distrib_names
}
try:
yield
finally:
os.environ.update(cluster_env)
def parallelize_model(
model,
device_mesh,
model_args,
distributed_args: DistributedArgs,
fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
tp_parallelize=None,
no_recompute_ops=None,
):
if distributed_args.tp_size > 1:
assert (
distributed_args.fsdp_type == "full_shard"
), "Only full shard is supported for TP parallelism"
assert tp_parallelize is not None, "TP plan is required for TP parallelism"
assert (
distributed_args.compile == False
), "Compile is not supported for TP parallelism"
tp_parallelize(model, device_mesh["tp"], model_args, distributed_args)
if distributed_args.float8_recipe is not None:
if distributed_args.tp_size > 1:
raise RuntimeError("float8 is incompatible with tensor-parallelism for now")
model = convert_linears_to_fp8(
model, distributed_args.float8_recipe, distributed_args.float8_filter
)
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
distributed_args.model_dtype
]
if (
distributed_args.fsdp_type == "full_shard"
or distributed_args.fsdp_type == "no_shard"
):
if distributed_args.fsdp_type == "no_shard":
assert (
distributed_args.dp_shard == 1
), "dp_shard must be 1 for no_shard fsdp_type"
assert (
device_mesh["dp_shard"].size() == 1
), "dp_shard must be 1 for no_shard fsdp_type"
fsdp_config = dict(
mp_policy=(
MixedPrecisionPolicy(
param_dtype=param_dtype,
reduce_dtype=torch.float32,
)
),
mesh=(
device_mesh["dp_replicate", "dp_shard"]
if distributed_args.dp_shard > 1
or distributed_args.fsdp_type == "no_shard"
else device_mesh["dp_replicate"]
),
)
if fsdp_grouping_plan is None:
# Assume that the model has list of layers and group around it
fsdp_grouping_plan = default_fsdp_grouping_plan(len(model.layers))
for path, reshard_after_forward in fsdp_grouping_plan:
module = get_module(model, path)
set_module(
model,
path,
fully_shard(
module, **fsdp_config, reshard_after_forward=reshard_after_forward
),
)
model = fully_shard(model, **fsdp_config, reshard_after_forward=True)
else:
raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
if distributed_args.selective_activation_checkpointing:
model = checkpoint_wrapper(
model,
context_fn=partial(
create_selective_checkpoint_contexts,
get_default_policy(no_recompute_ops),
),
)
if distributed_args.compile:
torch._dynamo.config.cache_size_limit = (
distributed_args.compile_cache_size_limit
)
model = torch.compile(model)
return model
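For illustration (not part of this commit), a pure-Python sketch of the mesh-shape arithmetic enforced by get_device_mesh; the example numbers (2 nodes of 8 GPUs, no tensor parallelism) are made up.

def mesh_dims(dp_replicate, dp_shard, tp_size, world_size, fsdp_type="full_shard"):
    assert dp_replicate * dp_shard * tp_size == world_size
    dims, names = [], []
    if dp_replicate >= 1:
        dims.append(dp_replicate)
        names.append("dp_replicate")
    if dp_shard > 1 or fsdp_type == "no_shard":
        dims.append(dp_shard)
        names.append("dp_shard")
    if tp_size > 1:
        dims.append(tp_size)
        names.append("tp")
    return tuple(dims), tuple(names)


assert mesh_dims(2, 8, 1, 16) == ((2, 8), ("dp_replicate", "dp_shard"))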

View file

@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import os
import re
import torch
from bytelatent.transformer import LMTransformer, LMTransformerArgs
def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
reloaded = json.loads(fr.read())
torch.set_default_dtype(torch.bfloat16)
model_params = reloaded["model"]
entropy_model = LMTransformer(
LMTransformerArgs(
dim=model_params["dim"],
n_layers=model_params["n_layers"],
n_heads=model_params["n_heads"],
max_seqlen=model_params["max_length"],
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
vocab_size=model_params["vocab_size"],
)
)
entropy_model.load_state_dict(
torch.load(state_dict_path, map_location=device), strict=False
)
entropy_model.to(device)
entropy_model = entropy_model.eval()
# no grads for the model:
for param in entropy_model.parameters():
param.requires_grad = False
return entropy_model
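For illustration (not part of this commit), a usage sketch with placeholder paths; unpacking the forward output as a (logits, cache) pair matches how calculate_entropies in bytelatent/data/patcher.py calls the model.

import torch

from bytelatent.entropy_model import load_entropy_model

# Placeholder paths: a directory containing params.json plus a saved state dict file.
entropy_model = load_entropy_model(
    "checkpoints/entropy_model",
    "checkpoints/entropy_model/consolidated.pth",
    device="cpu",
)
with torch.no_grad():
    logits, _ = entropy_model(torch.zeros((1, 16), dtype=torch.long))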

152
bytelatent/float8.py Normal file
View file

@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import re
import warnings
from typing import Callable
import torch
# avoid division by zero when calculating scale
EPS = 1e-12
def scale(t, amax_t, dtype_t):
min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max
scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v
t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t)
return t_fp8, scale_t
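
As a quick illustration of the row-wise scaling above, a minimal CPU sketch with made-up shapes (casting to float8_e4m3fn needs a reasonably recent PyTorch):

t = torch.randn(4, 8, dtype=torch.bfloat16)
amax_t = t.abs().amax(dim=-1, keepdim=True)
t_fp8, scale_t = scale(t, amax_t, torch.float8_e4m3fn)
# Dequantize and compare against the original values.
t_approx = t_fp8.to(torch.float32) * scale_t
print((t_approx - t.float()).abs().max())  # small quantization error
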
def matmul(
first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias
):
first_fp8, scale_first = scale(first, amax_first, dtype_first)
second_t_fp8, scale_second_t = scale(second_t, amax_second_t, dtype_second_t)
output = torch._scaled_mm(
first_fp8,
second_t_fp8.t(),
scale_a=scale_first,
scale_b=scale_second_t.t(),
bias=bias,
out_dtype=torch.bfloat16,
use_fast_accum=True,
)
return output
@torch._dynamo.allow_in_graph
class Fp8LinearFn(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b_t, bias):
amax_a = a.abs().amax(dim=-1, keepdim=True)
amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
out = matmul(
a, amax_a, torch.float8_e4m3fn, b_t, amax_b_t, torch.float8_e4m3fn, bias
)
ctx.a_requires_grad = a.requires_grad
ctx.b_requires_grad = b_t.requires_grad
ctx.bias_requires_grad = bias.requires_grad if bias is not None else False
ctx.save_for_backward(a, b_t, amax_b_t.max())
return out
@staticmethod
def backward(ctx, grad_out):
a, b_t, amax_b = ctx.saved_tensors
if ctx.a_requires_grad:
b = b_t.t().contiguous()
amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True)
amax_b = amax_b.repeat(b.shape[0], 1)
grad_a = matmul(
grad_out,
amax_grad_out,
torch.float8_e4m3fn,
b,
amax_b,
torch.float8_e4m3fn,
None,
)
else:
grad_a = None
if ctx.b_requires_grad:
grad_b = grad_out.t() @ a
else:
grad_b = None
if ctx.bias_requires_grad:
grad_bias = grad_out.sum(dim=0)
else:
grad_bias = None
return grad_a, grad_b, grad_bias
class Fp8Linear(torch.nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias)
out = out.unflatten(0, input.shape[:-1])
return out
def named_replace(
fn: Callable[[torch.nn.Module, str], torch.nn.Module],
module: torch.nn.Module,
name="",
) -> torch.nn.Module:
for child_name, child_module in list(module.named_children()):
full_name = f"{name}.{child_name}" if name else child_name
new_child_module = named_replace(fn, child_module, full_name)
setattr(module, child_name, new_child_module)
module = fn(module, name)
return module
def convert_linears_to_fp8(
root_module: torch.nn.Module, recipe: str, filter: str
) -> torch.nn.Module:
if recipe not in ["rowwise"]:
raise RuntimeError(f"Unknown float8 recipe {recipe!r}")
if recipe == "rowwise" and torch.__version__ < "2.5":
# We need https://github.com/pytorch/pytorch/pull/134781.
warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0")
# Multi-kernel makes Inductor auto-tune between a regular "streaming"-based
# reduction kernel and a "persistent" reduction kernel. Since fp8 has some
# multi-pass steps (e.g., first get amax, then scale), persistent kernels
# should perform better.
torch._inductor.config.triton.multi_kernel = 1
filter_re = re.compile(filter)
def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
return module
if type(module) == torch.nn.Linear:
if recipe == "rowwise":
new_module = Fp8Linear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
dtype=module.weight.dtype,
device=module.weight.device,
)
new_module.weight = module.weight
new_module.bias = module.bias
else:
assert False, recipe
else:
assert False, str(type(module))
return new_module
out = named_replace(replace, root_module)
# Force re-compile everything
torch._dynamo.reset_code_caches()
from torch._inductor.cudagraph_trees import reset_cudagraph_trees
reset_cudagraph_trees()
return out
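
A minimal sketch of the conversion on a toy module (module and filter names are illustrative; the fp8 matmul inside Fp8LinearFn itself needs a GPU with float8 support, so this only inspects the converted module tree):

import torch.nn as nn

toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
toy = convert_linears_to_fp8(toy, recipe="rowwise", filter=r".*")
print([type(m).__name__ for m in toy])  # ['Fp8Linear', 'ReLU', 'Fp8Linear']
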

129
bytelatent/logger.py Normal file
View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import math
import sys
import time
from datetime import timedelta
from bytelatent.distributed import get_global_rank, get_is_slurm_job
class LogFormatter(logging.Formatter):
"""
Custom logger for distributed jobs, displaying rank
and preserving indent from the custom prefix format.
"""
def __init__(self):
self.start_time = time.time()
self.rank = get_global_rank()
self.show_rank = not get_is_slurm_job() # srun has --label
def formatTime(self, record):
subsecond, seconds = math.modf(record.created)
curr_date = (
time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
+ f".{int(subsecond * 1_000_000):06d}"
)
delta = timedelta(seconds=round(record.created - self.start_time))
return f"{curr_date} - {delta}"
def formatPrefix(self, record):
fmt_time = self.formatTime(record)
if self.show_rank:
return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
else:
return f"{record.levelname:<7} {fmt_time} - "
def formatMessage(self, record, indent: str):
content = record.getMessage()
content = content.replace("\n", "\n" + indent)
# Exception handling as in the default formatter, albeit with indenting
# according to our custom prefix
if record.exc_info:
# Cache the traceback text to avoid converting it multiple times
# (it's constant anyway)
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
if content[-1:] != "\n":
content = content + "\n" + indent
content = content + indent.join(
[l + "\n" for l in record.exc_text.splitlines()]
)
if content[-1:] == "\n":
content = content[:-1]
if record.stack_info:
if content[-1:] != "\n":
content = content + "\n" + indent
stack_text = self.formatStack(record.stack_info)
content = content + indent.join([l + "\n" for l in stack_text.splitlines()])
if content[-1:] == "\n":
content = content[:-1]
return content
def format(self, record):
prefix = self.formatPrefix(record)
indent = " " * len(prefix)
content = self.formatMessage(record, indent)
return prefix + content
def set_root_log_level(log_level: str):
logger = logging.getLogger()
level: int | str = log_level.upper()
try:
level = int(log_level)
except ValueError:
pass
try:
logger.setLevel(level) # type: ignore
except Exception:
logger.warning(
f"Failed to set logging level to {log_level}, using default 'NOTSET'"
)
logger.setLevel(logging.NOTSET)
def init_logger(
log_file: str | None = None,
*,
name: str | None = None,
level: str = "NOTSET",
):
"""
Setup logging.
Args:
log_file: A file name to save file logs to.
name: The name of the logger to configure, by default the root logger.
level: The logging level to use.
"""
set_root_log_level(level)
logger = logging.getLogger(name)
# stdout: everything
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.NOTSET)
stdout_handler.setFormatter(LogFormatter())
# stderr: warnings / errors and above
stderr_handler = logging.StreamHandler(sys.stderr)
stderr_handler.setLevel(logging.WARNING)
stderr_handler.setFormatter(LogFormatter())
# set stream handlers
logger.handlers.clear()
logger.handlers.append(stdout_handler)
logger.handlers.append(stderr_handler)
if log_file is not None and get_global_rank() == 0:
# build file handler
file_handler = logging.FileHandler(log_file, "a")
file_handler.setLevel(logging.NOTSET)
file_handler.setFormatter(LogFormatter())
# update logger
logger = logging.getLogger()
logger.addHandler(file_handler)
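
A minimal usage sketch, assuming a single-process run outside SLURM where get_global_rank() resolves to rank 0:

init_logger(log_file=None, level="INFO")
logging.getLogger(__name__).info("logger initialised")
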

232
bytelatent/metrics.py Normal file
View file

@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import json
import logging
from collections import namedtuple
from dataclasses import asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Union
import torch
import torch.nn as nn
import wandb
from pydantic import BaseModel, ConfigDict
from bytelatent.distributed import get_is_master
logger = logging.getLogger()
class WandbArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
job_type: str | None = None
dir: str | None = None
project: str | None = None
entity: str | None = None
tags: list | None = None
group: str | None = None
name: str | None = None
notes: str | None = None
config_exclude_keys: list[str] | None = None
config_include_keys: list[str] | None = None
anonymous: str | None = None
mode: str | None = None
allow_val_change: bool | None = None
resume: Union[bool, str] | None = None
force: bool | None = None
tensorboard: bool | None = None
sync_tensorboard: bool | None = None
monitor_gym: bool | None = None
save_code: bool | None = None
id: str | None = None
fork_from: str | None = None
resume_from: str | None = None
class LoggingArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
freq: int = 10 # Log every freq optimizer steps
acc_freq: int | None = None # Log every acc_freq gradient accumulation steps
wandb: WandbArgs | None = None
class MetricLogger:
def __init__(self, outdir: Path, args: Any | None = None):
self.outdir = outdir
self.jsonl_writer = None
self.args = args
def open(self):
if self.jsonl_writer is None:
self.jsonl_writer = open(self.outdir, "a")
if (
self.args is not None
and self.args.logging.wandb is not None
and get_is_master()
):
run = wandb.init(
config=asdict(self.args),
**asdict(self.args.logging.wandb),
)
def log(self, metrics: dict[str, Any]):
if (
self.args is not None
and self.args.logging.wandb is not None
and (wandb.run is not None)
):
wandb.log(metrics, step=metrics["global_step"])
metrics.update({"created_at": datetime.now(timezone.utc).isoformat()})
print(json.dumps(metrics), file=self.jsonl_writer, flush=True)
def close(self):
if self.jsonl_writer is not None:
self.jsonl_writer.close()
self.jsonl_writer = None
def __enter__(self):
self.open()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __del__(self):
self.close()
GPUMemStats = namedtuple(
"GPUMemStats",
[
"max_active_gib",
"max_active_pct",
"max_reserved_gib",
"max_reserved_pct",
"num_alloc_retries",
"num_ooms",
"power_draw",
],
)
class GPUMemoryMonitor:
"""
Class to monitor GPU memory usage
"""
def __init__(self, device: str = "cuda:0"):
self.device = torch.device(device) # device object
self.device_name = torch.cuda.get_device_name(self.device)
self.device_index = torch.cuda.current_device()
self.device_capacity = torch.cuda.get_device_properties(
self.device
).total_memory
self.device_capacity_gib = self._to_gib(self.device_capacity)
# reset stats, clear cache
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
def _to_gib(self, memory_in_bytes):
        # NOTE: GiB (gibibyte) uses 1024**3 bytes, vs GB (gigabyte) at 1000**3
_gib_in_bytes = 1024 * 1024 * 1024
memory_in_gib = memory_in_bytes / _gib_in_bytes
return memory_in_gib
def _to_pct(self, memory):
return 100 * memory / self.device_capacity
def get_peak_stats(self):
cuda_info = torch.cuda.memory_stats(self.device)
max_active = cuda_info["active_bytes.all.peak"]
max_active_gib = self._to_gib(max_active)
max_active_pct = self._to_pct(max_active)
max_reserved = cuda_info["reserved_bytes.all.peak"]
max_reserved_gib = self._to_gib(max_reserved)
max_reserved_pct = self._to_pct(max_reserved)
num_retries = cuda_info["num_alloc_retries"]
num_ooms = cuda_info["num_ooms"]
power_draw = torch.cuda.power_draw()
if num_retries > 0:
logger.warning(f"{num_retries} CUDA memory allocation retries.")
if num_ooms > 0:
logger.warning(f"{num_ooms} CUDA OOM errors thrown.")
return GPUMemStats(
max_active_gib,
max_active_pct,
max_reserved_gib,
max_reserved_pct,
num_retries,
num_ooms,
power_draw,
)
def reset_peak_stats(self):
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
def __str__(self):
mem_stats = self.get_peak_stats()
        display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib:.2f} GiB capacity, "
        display_str += (
            f"{mem_stats.max_reserved_gib:.2f} GiB peak, {mem_stats.max_reserved_pct:.2f}% peak"
        )
return f"{display_str}"
def upload_train_to_wandb(
ckpt_dir, project="lingua", entity="codegen-team", train=True, eval=True
):
import json
from pathlib import Path
import wandb
from omegaconf import OmegaConf
cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml")
cfg = OmegaConf.to_container(cfg)
if train:
wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
with open(Path(ckpt_dir) / "metrics.jsonl") as f:
for l in f:
m = json.loads(l)
wandb.log(m, step=m["global_step"])
wandb.finish()
if eval:
wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
with open(Path(ckpt_dir) / "metrics.eval.jsonl") as f:
for l in f:
m = json.loads(l)
wandb.log(
{
f"evals/{name.replace('/','.')}": value
for name, value in m.items()
if "/" in name
},
step=m["global_step"],
)
wandb.finish()
def get_num_params(model: nn.Module) -> int:
"""
Get the total model params
Args : only_trainable: whether to only count trainable params
"""
numel = {n: p.numel() for n, p in model.named_parameters()}
return sum(numel.values())
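
A toy usage sketch (path and values are illustrative): with args=None the wandb branch is skipped, and MetricLogger simply appends one JSON line per log() call.

tiny = nn.Linear(8, 8)
print("params:", get_num_params(tiny))  # 8 * 8 weights + 8 biases = 72

with MetricLogger(Path("/tmp/metrics.jsonl")) as metric_logger:
    metric_logger.log({"global_step": 1, "loss": 2.0})
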

View file

@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

1064
bytelatent/model/blt.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,356 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from typing import List, Optional, Tuple, Union
import torch
import torch.nn
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias
from bytelatent.base_transformer import (
InitStdFactor,
RMSNorm,
RotaryEmbedding,
TransformerBlock,
)
from bytelatent.model.transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask, downsample
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
logger = logging.getLogger()
class LocalModelBase(nn.Module):
def __init__(self, args):
super().__init__()
self.dim = args.dim
self.dropout = args.dropout
self.vocab_size = args.vocab_size + args.pm_size
self.patch_size = args.patch_size
self.efficient_attn = args.efficient_attn
self.sliding_window = args.sliding_window
self.use_rope = args.use_rope
self.init_std_factor = args.init_std_factor
self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
self.cross_attn_k = getattr(args, "cross_attn_k", None)
self.boe_id = BOE_ID
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.layers = nn.ModuleList(
[TransformerBlock(args) for _ in range(args.n_layers)]
)
self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
if not self.use_rope:
self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
else:
self.rope = RotaryEmbedding(
theta=args.rope_theta,
head_dim=args.head_dim or args.dim // args.n_heads,
max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
)
self.pos_embeddings = None
self.token_embedding_projection = (
nn.Linear(args.dim_token_emb, args.dim, bias=False)
if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
else None
)
self.patch_embedding_projection = self._create_patch_projection(args)
def _should_create_patch_projection(self, args):
dimension_mismatch = (
getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
)
# Check cross attention conditions
cross_attn_conditions = (
hasattr(args, "cross_attn_encoder")
and args.cross_attn_encoder
and getattr(args, "cross_attn_init_by_pooling")
) or (
hasattr(args, "cross_attn_decoder")
and args.cross_attn_decoder
and getattr(args, "cross_attn_init_by_pooling")
)
return dimension_mismatch or cross_attn_conditions
def _create_patch_projection(self, args):
if not self._should_create_patch_projection(args):
return None
output_dim = args.dim_token_emb * (self.cross_attn_k or 1)
return nn.Linear(
in_features=args.dim_patch_emb,
out_features=output_dim,
bias=False,
)
def apply_embedding(self, tokens, embeds):
if embeds is not None:
return embeds
else:
return self.tok_embeddings(tokens)
    def init_weights(self, init_std=None):
        if self.use_rope:
            self.rope.reset_parameters()
        init_std = init_std or (self.dim ** (-0.5))
nn.init.trunc_normal_(
self.tok_embeddings.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
if self.pos_embeddings is not None:
nn.init.trunc_normal_(
self.pos_embeddings.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
for depth, layer in enumerate(self.layers):
factor = {
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
InitStdFactor.DIM_RATIO: self.dim / 4096,
InitStdFactor.DISABLED: 1.0,
}[self.init_std_factor]
layer.init_weights(init_std, factor)
if self.token_embedding_projection is not None:
nn.init.trunc_normal_(
self.token_embedding_projection.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
if self.patch_embedding_projection is not None:
nn.init.trunc_normal_(
self.patch_embedding_projection.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
if hasattr(self, "output"):
nn.init.trunc_normal_(
self.output.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
        if getattr(self, "cross_attn_layers", None) is not None:
for depth, layer in enumerate(self.cross_attn_layers):
factor = {
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
InitStdFactor.DIM_RATIO: self.dim / 4096,
InitStdFactor.DISABLED: 1.0,
}[self.init_std_factor]
layer.init_weights(init_std, factor)
class LocalEncoder(LocalModelBase):
def __init__(self, args):
super().__init__(args)
self.output_proj = (
args.patching_mode in ["entropy", "probmax"]
) and args.entropy_model_checkpoint_dir is None
self.apply_transformer = args.use_local_encoder_transformer
self.downsampling_by_pooling = args.downsampling_by_pooling
self.patch_only = args.patch_only_encoder
self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
self.cross_attn_encoder = args.cross_attn_encoder
self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
self.cross_attn_nheads = args.cross_attn_nheads
if self.cross_attn_encoder:
self.cross_attn_layers = torch.nn.ModuleList()
layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
for _ in range(layers_to_add):
self.cross_attn_layers.append(
CrossAttention(
dim=self.dim,
head_dim=self.dim // self.cross_attn_nheads,
n_heads=self.cross_attn_nheads,
n_kv_heads=self.cross_attn_nheads,
norm_eps=args.norm_eps,
)
)
def apply_embedding(self, tokens, embeds):
if embeds is not None:
assert (
self.expects_hash_embeddings
), "Not expecting embeddings to be passed."
return embeds
else:
return self.tok_embeddings(tokens)
def forward(
self,
tokens: torch.Tensor,
embeds: Optional[torch.Tensor] = None,
patch_embeds: Optional[torch.Tensor] = None,
mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
cross_mask: Optional[torch.Tensor] = None,
num_patches: Optional[int] = None,
patch_ids: Optional[torch.Tensor] = None,
cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
):
""" """
bs, seqlen = tokens.shape
if mask is None:
mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
h = self.apply_embedding(tokens, embeds)
freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
h = F.dropout(h, p=self.dropout, training=self.training)
for i, layer in enumerate(self.layers):
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
# check if cross attention should be applied to either all layer or only the last layer
if self.cross_attn_encoder and (
i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
):
patch_embeds = self.apply_cross_attention(
h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask
)
h_residual = patch_embeds if self.cross_attn_encoder else None
return (h, h_residual), cache
def apply_cross_attention(
self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask
):
# apply pooling and project
if self.cross_attn_init_by_pooling and patch_embeds is None:
patch_embeds = downsample(
h,
num_patches,
patch_ids=patch_ids,
downsampling_by_pooling=self.downsampling_by_pooling,
patch_size=self.patch_size,
)
if self.patch_embedding_projection is not None:
patch_embeds = self.patch_embedding_projection(patch_embeds)
patch_embeds = patch_embeds.reshape(
bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
)
layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0
patch_embeds_cross = self.cross_attn_layers[layer_idx](
x=patch_embeds,
kv=h,
mask=cross_mask,
)
patch_embeds += patch_embeds_cross
return patch_embeds
class LocalDecoder(LocalModelBase):
def __init__(self, args):
super().__init__(args)
# Model configuration flags
self.patch_only = args.patch_only_decoder
self.expects_embeddings = args.share_encoder_decoder_emb
self.cross_attn_decoder = args.cross_attn_decoder
self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
self.cross_attn_nheads = args.cross_attn_nheads
if self.cross_attn_decoder:
self.cross_attn_layers = torch.nn.ModuleList()
layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1
for _ in range(layers_to_add):
self.cross_attn_layers.append(
CrossAttention(
dim=self.dim,
head_dim=self.dim // self.cross_attn_nheads,
n_heads=self.cross_attn_nheads,
n_kv_heads=self.cross_attn_nheads,
norm_eps=args.norm_eps,
)
)
self.output = nn.Linear(
self.dim,
args.vocab_size,
bias=False,
)
def forward(
self,
tokens: torch.Tensor,
embeds: Optional[torch.Tensor],
patch_embeds: Optional[torch.Tensor] = None,
mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
cross_mask: Optional[torch.Tensor] = None,
cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
):
bs, seqlen = tokens.shape
assert embeds is not None, "Embeddings must be provided"
if mask is None:
mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
h = embeds
if self.patch_embedding_projection is not None:
assert patch_embeds is not None, "Patch embeddings must be passed."
patch_embeds = self.patch_embedding_projection(patch_embeds)
if self.cross_attn_k is not None:
patch_embeds = patch_embeds.reshape(
bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
)
if patch_embeds is not None and not self.cross_attn_decoder:
h = h + patch_embeds
freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
h = F.dropout(h, p=self.dropout, training=self.training)
for i, layer in enumerate(self.layers):
if self.cross_attn_decoder and (
i == 0 or self.cross_attn_all_layers_decoder
):
# Use cross attention to extract info from patch_embeds into h
h_cross = self.cross_attn_layers[i](
x=h,
kv=patch_embeds,
mask=cross_mask,
)
h = h + h_cross
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
h_preds = self.norm(h)
h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
h_preds = self.output(h_preds)
h_preds = h_preds.float()
return h_preds, cache

View file

@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from typing import List, Optional, Tuple, Union
import torch
import torch.nn
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias
from bytelatent.base_transformer import (
BaseTransformer,
RMSNorm,
flex_attention_comp,
repeat_kv,
)
from bytelatent.model.utils import create_causal_mask
logger = logging.getLogger()
class CrossAttention(nn.Module):
"""
CrossAttention block to attend to the encoder states from the decoder.
Rope is not supported.
"""
def __init__(
self,
dim: int,
head_dim: int,
n_heads: int,
n_kv_heads: int,
norm_eps: float,
):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.heads_per_group = self.n_heads // self.n_kv_heads
self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
self.wq = nn.Linear(
dim,
n_heads * head_dim,
bias=False,
)
self.wk = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
)
self.wv = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
)
self.wo = nn.Linear(
n_heads * head_dim,
dim,
bias=False,
)
def forward(
self,
x: torch.Tensor,
kv: torch.Tensor,
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
) -> torch.Tensor:
# B S D
bsz, seq_len, _ = x.shape
_, slen_kv, _ = kv.shape
x = self.cross_attn_norm_q(x)
kv = self.cross_attn_norm_kv(kv)
xq = self.wq(x)
xk = self.wk(kv)
xv = self.wv(kv)
output_shape = xq.shape
# B S D -> B S H D
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
xk = repeat_kv(xk, self.heads_per_group, dim=2)
xv = repeat_kv(xv, self.heads_per_group, dim=2)
assert mask is None or isinstance(mask, BlockMask)
xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
output = self.wo(output.reshape(output_shape))
return x + output
def init_weights(self, base_std: float, factor: float = 1.0):
std = base_std * factor
nn.init.trunc_normal_(
self.wq.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
nn.init.trunc_normal_(
self.wk.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
nn.init.trunc_normal_(
self.wv.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
)
output_std = std / (2**0.5)
nn.init.trunc_normal_(
self.wo.weight,
mean=0.0,
std=output_std,
a=-3 * output_std,
b=3 * output_std,
)
self.cross_attn_norm_q.reset_parameters()
self.cross_attn_norm_kv.reset_parameters()
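
A shape sketch for the cross-attention block above (sizes are illustrative; mask=None means full, unmasked cross-attention, and the flex_attention kernel requires a recent PyTorch build):

xattn = CrossAttention(dim=64, head_dim=16, n_heads=4, n_kv_heads=4, norm_eps=1e-5)
queries = torch.randn(2, 8, 64)   # e.g. patch states (B, S_q, D)
keyvals = torch.randn(2, 32, 64)  # e.g. byte states  (B, S_kv, D)
print(xattn(queries, keyvals).shape)  # torch.Size([2, 8, 64])
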
class GlobalTransformer(BaseTransformer):
def __init__(self, args):
super().__init__(args)
self.dropout = args.dropout
self.sliding_window = args.sliding_window
self.efficient_attn = args.efficient_attn
self.token_embedding_projection = None
if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
self.token_embedding_projection = nn.Linear(
args.dim_token_emb,
args.dim,
bias=False,
)
def forward(
self,
tokens: torch.Tensor,
tok_idx: Optional[torch.Tensor] = None,
embeds: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
):
"""
Similar to BaseTransformer.forward, but with an additional embeds argument
and projection to the token space.
"""
bs, seqlen = tokens.shape
attn_impl = self.efficient_attn
h = embeds
mask = (
mask
if mask is not None
else create_causal_mask(seqlen, attn_impl, self.sliding_window)
)
if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
h = self.token_embedding_projection(h)
h = F.dropout(h, p=self.dropout, training=self.training)
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
return h, cache
def init_weights(self, init_base_std: float):
super().init_weights()
if self.token_embedding_projection is not None:
nn.init.trunc_normal_(
self.token_embedding_projection.weight,
mean=0.0,
std=init_base_std,
a=-3 * init_base_std,
b=3 * init_base_std,
)

116
bytelatent/model/utils.py Normal file
View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
from torch.nn.attention.flex_attention import create_block_mask
from xformers.ops import fmha
def patch_reduce(h, max_num_patches, reduction, patch_ids):
"""
Reduce variable length patches to single embedding per patch
Note: this works with variable number of patches for different sequences in the batch
It handles variable length patches by assuming that patch_lengths will be 0 for any
extra patches on the *right*. Since there can be a variable number of patches
this function also return the number of patches for each sequence in the batch.
Any embeddings on the right that are not allocated to a patch
(i.e. if the sum(patch_lengths[i]) < seq_len for any i)
will be sent to a dummy patch, which is trimmed before returning.
"""
bs, seq_len, emb_dim = h.shape
patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
reduced_embs = torch.zeros(
(bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device
)
reduced_embs = reduced_embs.scatter_reduce(
src=h,
dim=1,
index=patch_ids,
reduce=reduction,
include_self=False,
)
reduced_embs = reduced_embs[:, :max_num_patches, :]
return reduced_embs
def concat_downsample(h, patch_lengths, patch_size):
# The assumption in this function is that seq_len = patch_size * num_patches.
bs, seq_len, emb_dim = h.shape
patch_end_ids = torch.cumsum(patch_lengths, dim=1)
patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to(
patch_end_ids.device
)
# Is clamp ok here?
patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1])
patch_ids = patch_ids.view(bs, -1, emb_dim)
    # after gather, h.shape = [batch_size, num_patches * patch_size, emb_dim]
h = torch.gather(h, 1, patch_ids)
h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1))
return h
def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids):
cat = []
if "avg" in pooling_mode or "mean" in pooling_mode:
cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids))
if "min" in pooling_mode:
cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids))
if "max" in pooling_mode:
cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids))
assert len(cat) > 0
h = torch.cat(cat, dim=-1)
return h
def downsample(
h,
num_patches,
patch_lengths=None,
patch_ids=None,
downsampling_by_pooling=None,
patch_size=4,
):
"""
Downsampling:
a. concatenating embeddings in the patch
Note: with dynamic patching, patch the last patch_size tokens.
b. pooling embeddings in the patch
"""
# input: h.shape = [batch_size, seq_len, dim]
# input: pool h.shape = [batch_size, seq_len / patch_size, dim]
# if we don't use the cros_attn, we pool so that we convert bytes rep to patch rep
if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0:
# By pooling
max_num_patches = num_patches
assert patch_ids is not None
h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids)
else:
# TODO: remove this condition
# By concatenating (fixed lengths patching)
assert patch_lengths is not None
h = concat_downsample(h, patch_lengths, patch_size)
return h
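
A concrete illustration of the pooling path above with toy shapes (the byte-to-patch assignment is made up for the example):

h = torch.arange(6, dtype=torch.float32).reshape(1, 6, 1)  # (bs, seq_len, emb_dim)
patch_ids = torch.tensor([[0, 0, 1, 1, 1, 2]])             # byte -> patch assignment
pooled = downsample(h, num_patches=3, patch_ids=patch_ids, downsampling_by_pooling="avg")
print(pooled.squeeze(-1))  # tensor([[0.5000, 3.0000, 5.0000]])
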
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def create_causal_mask(seqlen, attn_impl, sliding_window):
if sliding_window is not None and attn_impl == "xformers":
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
window_left=sliding_window - 1, window_right=0
)
elif attn_impl == "xformers":
return fmha.attn_bias.LowerTriangularMask()
elif attn_impl == "sdpa":
return "causal"
elif attn_impl == "flex_attention":
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
elif attn_impl == "fmha":
return None
else:
raise NotImplementedError(
f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
)

162
bytelatent/optim.py Normal file
View file

@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import math
from functools import partial
from pydantic import BaseModel, ConfigDict
from torch import nn
from torch.optim import AdamW, lr_scheduler
logger = logging.getLogger()
class OptimArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
lr: float = 3e-4
weight_decay: float = 0.1
epsilon: float = 1e-8
beta1: float = 0.9
beta2: float = 0.95
clip: float = 1.0
scheduler: str = "cosine"
warmup: int = 2000
lr_min_ratio: float = 0.1
cycle_length: float = 1.0
cosine_theta: float = 1.0
annealing_step: int = 1000
decay_fraction: float = 0.1
exp_factor: float = 0.5
def lr_linear(step: int, warmup: int, n_steps: int, min_ratio: float) -> float:
if step < warmup:
lr = float(step) / warmup
elif step <= n_steps:
s = float(step - warmup) / (n_steps - warmup)
lr = s * min_ratio + (1 - s)
else:
lr = min_ratio
return lr
def lr_inv_sqrt(step: int, warmup: int, exp_factor: float, min_ratio: float) -> float:
if step < warmup:
lr = float(step) / warmup
else:
lr = max((warmup**exp_factor) / (step**exp_factor), min_ratio)
return lr
def lr_cosine(
step: int,
warmup: int,
n_steps: int,
cycle_length: float,
theta: float,
min_ratio: float,
) -> float:
sign = ((step // (n_steps * cycle_length)) % 2) * -2 + 1
if step < warmup:
lr = float(step) / warmup
elif step <= n_steps:
s = float(step - warmup) / (n_steps - warmup)
lr = min_ratio + 0.5 * (1 - min_ratio) * (
sign * math.cos(math.pi * s**theta / cycle_length) + 1
)
else:
lr = min_ratio
return lr
def lr_wsd(
step: int,
warmup: int,
n_steps: int,
decay_fraction: float,
cycle_length: float,
min_ratio: float,
) -> float:
"""
UNDERSTANDING WARMUP-STABLE-DECAY LEARNING RATES: A RIVER VALLEY LOSS LANDSCAPE PERSPECTIVE
https://arxiv.org/pdf/2410.05192
"""
cycle_num = step // int(n_steps * cycle_length) + 1
curr_n_steps = int(n_steps * cycle_length) * cycle_num
decay_length = int(curr_n_steps * decay_fraction)
if step < warmup:
lr = float(step) / warmup
elif step <= curr_n_steps - decay_length:
lr = 1.0
elif step > curr_n_steps - decay_length and step <= curr_n_steps:
# Linear interpolation gives similar results
# slope = -(1.0 - min_ratio) / decay_length
# intercept = min_ratio + ((1.0 - min_ratio) * curr_n_steps) / decay_length
# lr = slope * step + intercept
step = step - (curr_n_steps - decay_length)
lr = 1 / ((step / curr_n_steps) * (1 / min_ratio) + (1 - step / curr_n_steps))
else:
lr = min_ratio
return lr
def build_lr_fn(args: OptimArgs, n_steps: int):
if args.scheduler == "constant":
lr_fn = lambda x: 1.0
elif args.scheduler == "linear":
lr_fn = partial(
lr_linear, warmup=args.warmup, n_steps=n_steps, min_ratio=args.lr_min_ratio
)
elif args.scheduler == "inv_sqrt":
lr_fn = partial(
lr_inv_sqrt,
warmup=args.warmup,
exp_factor=args.exp_factor,
min_ratio=args.lr_min_ratio,
)
elif args.scheduler == "cosine":
lr_fn = partial(
lr_cosine,
warmup=args.warmup,
n_steps=n_steps,
cycle_length=args.cycle_length,
theta=args.cosine_theta,
min_ratio=args.lr_min_ratio,
)
elif args.scheduler == "wsd":
assert args.decay_fraction < args.cycle_length
lr_fn = partial(
lr_wsd,
warmup=args.warmup,
n_steps=n_steps,
decay_fraction=args.decay_fraction,
cycle_length=args.cycle_length,
min_ratio=args.lr_min_ratio,
)
else:
raise NotImplementedError(f"Unknown scheduler: {args.scheduler}")
return lr_fn
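
The schedule functions can be sanity-checked directly; the values below are the multiplicative factors applied to OptimArgs.lr by LambdaLR (numbers are approximate):

args = OptimArgs(scheduler="cosine", warmup=10)
lr_fn = build_lr_fn(args, n_steps=100)
print([round(lr_fn(s), 3) for s in (0, 5, 10, 50, 99)])  # ~[0.0, 0.5, 1.0, 0.628, 0.1]
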
def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int):
logger.info("Starting build of optimizer...")
optimizer = AdamW(
model.parameters(),
lr=args.lr,
betas=(args.beta1, args.beta2),
weight_decay=args.weight_decay,
eps=args.epsilon,
fused=True, # Faster optim.step but can throw errors
)
# scheduler
lr_fn = build_lr_fn(args, n_steps)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_fn)
logger.info("Done with build of optimizer.")
return optimizer, scheduler

View file

@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

View file

@ -0,0 +1,3 @@
data_path: plot_data/entropy_figure.json
chart_path: figures/entropy_figure.pdf
# chart_path: figures/entropy_figure.pdf

View file

@ -0,0 +1,4 @@
df_dir: /home/par/blt_df
output_chart_dir: figures/
frame_files:
["4b_df.json", "500m_df.json", "scaling_arch_df.json", "scaling_df.json"]

View file

@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
import sys
from pathlib import Path
import altair as alt
import pandas as pd
from omegaconf import OmegaConf
from pydantic import BaseModel
class PlotEntropiesConfig(BaseModel):
data_path: str | None
chart_path: str
class Config:
extra = "forbid"
class PlotEntropiesData(BaseModel):
text: str
threshold: float = 1.335442066192627
dataframe_json: str | None
class Config:
extra = "forbid"
def main():
config_path = sys.argv[1]
file_config = OmegaConf.load(config_path)
# Omit program name and config file name
cli_conf = OmegaConf.from_cli(sys.argv[2:])
conf_dict = OmegaConf.to_container(
OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
)
plot_config = PlotEntropiesConfig(**conf_dict)
with open(plot_config.data_path) as f:
json_data = f.read()
plot_data = PlotEntropiesData.model_validate_json(json_data)
df = pd.read_json(plot_data.dataframe_json)
x_ticks = []
for row in df.itertuples():
position = row.position
token = row.tokens
x_ticks.append(f"{str(position).zfill(3)}|{token}")
df["position_with_token"] = x_ticks
print(df)
x_axis = alt.Axis(
labelExpr="split(datum.label, '|')[1]",
grid=False,
labelOverlap=False,
labelAngle=0,
)
width = 1200
height = 150
base = alt.Chart(df).properties(width=width, height=height)
points = base.mark_line(point=True).encode(
x=alt.X("position_with_token:O", title=None, axis=x_axis),
y=alt.Y(
"entropies",
title="Entropy of Next Byte",
),
)
rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
y=alt.datum(plot_data.threshold),
)
patch_rules = (
alt.Chart(df[df["start"] > 0])
.properties(width=width, height=height)
.mark_rule(color="#474747", strokeDash=[4, 2])
.encode(x=alt.X("position_with_token:O", axis=x_axis))
)
chart = patch_rules + rule + points
chart = chart.configure_axis(labelFontSize=15, titleFontSize=15)
path = Path(plot_config.chart_path)
path.parent.mkdir(exist_ok=True)
chart.save(path)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import sys
from pathlib import Path
import altair as alt
import pandas as pd
import pydantic
from omegaconf import OmegaConf
class ScalingPlotsConfig(pydantic.BaseModel):
df_dir: str
output_chart_dir: str
frame_files: list[str]
class Config:
extra = "forbid"
def determine_family(key: str):
if key.startswith("Megabyte++"):
return "Megabyte++"
elif key.startswith("BLT"):
return "BLT"
elif key.startswith("LLaMA"):
return "LLaMA"
elif key.startswith("Space"):
return "Space"
file_to_vars = {}
def create_chart(df: pd.DataFrame, output_file: str):
df["metric"] = df["bpb/not_heldout.jsonl"]
df["family"] = df["key"].map(determine_family)
model_domain = [
"BLT Space ps=6",
"BLT Space w/o cross-attn",
"SpaceByte",
"LLaMA 3 BPE",
"Megabyte++ ps=4",
"Megabyte++ ps=6",
]
color_range = ["#1f77b4", "#1f77b4", "#1f77b4", "#ff7f0e", "#2ca02c", "#2ca02c"]
shape_range = [
"circle",
"square",
"cross",
"diamond",
"triangle-up",
"triangle-down",
]
color_scale = alt.Scale(domain=model_domain, range=color_range)
shape_scale = alt.Scale(
domain=model_domain,
range=shape_range,
)
base_chart = alt.Chart(df).encode(
x=alt.X("flops", title="Training FLOPS")
.scale(type="log", domain=[2e20, 1.25e22])
.axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]),
y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False),
)
lines = base_chart.encode(
color=alt.Color("key", title="Model Color", scale=color_scale, legend=None),
strokeDash=alt.StrokeDash("family", title="Model Family", legend=None),
).mark_line()
points = base_chart.encode(
color=alt.Color("key", title="Model", scale=color_scale),
shape=alt.Shape("key", title="", scale=shape_scale),
).mark_point(size=70)
chart = (
(lines + points)
.resolve_scale(
color="independent",
shape="independent",
# strokeDash="independent",
)
.configure_legend(orient="right")
.properties(height=300, width=400)
)
print("Saving", output_file)
chart.save(output_file)
def main():
config_path = sys.argv[1]
file_config = OmegaConf.load(config_path)
# Omit program name and config file name
cli_conf = OmegaConf.from_cli(sys.argv[2:])
conf_dict = OmegaConf.to_container(
OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
)
plot_config = ScalingPlotsConfig(**conf_dict)
df_dir = Path(plot_config.df_dir)
chart_dir = Path(plot_config.output_chart_dir)
chart_dir.mkdir(exist_ok=True, parents=True)
for ff in plot_config.frame_files:
path = df_dir / ff
df = pd.read_json(path)
print(df)
print(df.columns)
create_chart(df, chart_dir / f"{path.name}.pdf")
if __name__ == "__main__":
main()

View file

@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

View file

@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import subprocess
from pathlib import Path
import luigi
# CHANGEME: Change this to point to your data
BASE_DIR = Path("datasets")
DATASETS = ["dclm"]
TARGET_DIR = Path("entropy_preprocess")
SHARD_SCRIPT = """split -C 2500m -d {source} {destination}.shard_"""
def list_dataset_shards(dataset: str):
dataset_dir = BASE_DIR / dataset
return list(dataset_dir.glob("*.chunk.*.jsonl"))
class ChunkFile(luigi.ExternalTask):
file = luigi.Parameter()
def output(self):
return luigi.LocalTarget(self.file)
class ShardDatasetChunk(luigi.Task):
dataset_name = luigi.Parameter()
chunk_file = luigi.Parameter()
def _chunk_filename(self):
return Path(self.chunk_file).name
def requires(self):
return ChunkFile(self.chunk_file)
def run(self):
destination_dir = TARGET_DIR / str(self.dataset_name)
destination_dir.mkdir(parents=True, exist_ok=True)
destination = destination_dir / self._chunk_filename()
subprocess.check_output(
SHARD_SCRIPT.format(source=str(self.chunk_file), destination=destination),
shell=True,
)
(
Path(TARGET_DIR)
/ str(self.dataset_name)
/ f"{self._chunk_filename()}.shard.COMPLETE"
).touch()
def output(self):
return luigi.LocalTarget(
TARGET_DIR
/ str(self.dataset_name)
/ f"{self._chunk_filename()}.shard.COMPLETE"
)
class ShardDataset(luigi.WrapperTask):
dataset_name = luigi.Parameter()
def requires(self):
for f in list_dataset_shards(self.dataset_name):
yield ShardDatasetChunk(dataset_name=self.dataset_name, chunk_file=str(f))
class ShardAllDatasets(luigi.WrapperTask):
def requires(self):
for d in DATASETS:
yield ShardDataset(dataset_name=d)
if __name__ == "__main__":
luigi.build([ShardAllDatasets()], local_scheduler=True, workers=128)

View file

@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import subprocess
from pathlib import Path
import submitit
import typer
class PreprocessEntropiesJob(submitit.helpers.Checkpointable):
def __init__(self) -> None:
pass
def __call__(self, shard_file: str, output_filename: str):
subprocess.run(
[
"python",
"-u",
"-m",
"bytelatent.preprocess.preprocess_entropies",
str(shard_file),
str(output_filename),
],
check=True,
)
return True
def chunk(items, size):
for i in range(0, len(items), size):
yield items[i : i + size]
def main(
job_folder: str,
input_dir: str,
output_dir: str,
qos: str = "explore",
slurm_batch_size: int = 1000,
check_only: bool = False,
wait: bool = False,
):
input_dir = Path(input_dir)
output_dir = Path(output_dir)
shard_files = [
p for p in input_dir.glob("*.jsonl.shard*") if "COMPLETE" not in p.name
]
if check_only:
exist = []
missing = []
for shard_file in shard_files:
shard_file = Path(shard_file)
complete_file = output_dir / f"{shard_file.name}.arrow.complete"
if complete_file.exists():
exist.append(complete_file)
else:
missing.append(complete_file)
print("Checked for output files for input_dir=", input_dir)
print("Exist:", len(exist))
print("Missing:", len(missing))
print(missing)
return
print("Running parallel job over N files=", len(shard_files))
print("Input Directory:", input_dir)
print("Output Directory:", output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
executor = submitit.SlurmExecutor(job_folder)
executor.update_parameters(
# 12 hours in minutes
time=60 * 12,
qos=qos,
exclusive="user",
cpus_per_task=4,
num_gpus=1,
mem_per_gpu="80G",
array_parallelism=slurm_batch_size,
)
jobs = []
n_batches = 0
n_skipped = 0
n_launched = 0
for file_batch in chunk(shard_files, slurm_batch_size):
with executor.batch():
for shard_file in file_batch:
output_filename = Path(output_dir) / f"{shard_file.name}.arrow"
complete_output_filename = (
Path(output_dir) / f"{shard_file.name}.arrow.complete"
)
if complete_output_filename.exists():
n_skipped += 1
else:
job = executor.submit(
PreprocessEntropiesJob(), str(shard_file), str(output_filename)
)
n_launched += 1
jobs.append(job)
n_batches += 1
print("launched array jobs n=", n_launched)
print("skipped (completed) array jobs n=", n_skipped)
print("number of slurm batches=", n_batches)
if wait:
output = [job.result() for job in jobs]
assert all(output)
if __name__ == "__main__":
typer.run(main)

View file

@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import time
from pathlib import Path
import numpy as np
import pyarrow as pa
import torch
import typer
from rich.progress import Progress, TextColumn
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
def main(
input_file: str,
output_file: str,
patching_device: str = "cuda",
log_step: int = 10_000,
entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
dry_run: bool = False,
):
# TODO: Modify this to work with the new code
raise NotImplementedError()
iterator = ArrowFileIterator(
file_path=input_file,
worker_id=0,
num_workers=1,
)
tokenization_mode = "bytes"
print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
print("Loading entropy model", entropy_model_checkpoint_dir)
if dry_run:
return
entropy_model = load_entropy_model(
entropy_model_checkpoint_dir, device=patching_device
)
entropy_model, _ = to_device(entropy_model, patching_device)
print("Creating patcher")
patching_batch_size = 32
print("Creating tokenizer")
tokenizer = Tokenizer(
model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
tokenization_mode=tokenization_mode,
# BYTE_UNITS
vocab_size_unit_1=256,
bos=True,
eos=True,
bpe_delim=False,
# This isn't used, just stores a reference for other calls we don't use
patcher=None,
)
step = 0
print("starting")
start_time = time.time()
patch_time = 0
entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
text_field = pa.field("text", pa.string(), nullable=False)
schema = pa.schema([sample_id_field, text_field, entropy_field])
arrow_batch_size = 1_000
try:
with pa.OSFile(output_file, "wb") as sink:
with pa.ipc.new_file(sink, schema) as writer:
id_buffer = []
entropies_buffer = []
text_buffer = []
with Progress(
*Progress.get_default_columns(),
TextColumn("Completed: {task.completed}"),
) as progress:
task = progress.add_task(
"[green]Calculating entropies...", total=None
)
for doc in iterator:
sample_id = get_id_from_doc(doc)
if "text" in doc:
text = doc["text"]
elif "content" in doc:
text = doc["content"]
else:
raise ValueError(
f"Could not find a text key from: {doc.keys()}"
)
tokens = torch.tensor(tokenizer.encode(text))
patch_start = time.time()
scores = calculate_entropies(
tokens,
entropy_model,
patching_batch_size,
patching_device,
)
entropies_buffer.append(
np.array(scores.tolist(), dtype=np.float16)
)
id_buffer.append(sample_id)
text_buffer.append(text)
if len(entropies_buffer) == arrow_batch_size:
batch = pa.record_batch(
{
"entropies": entropies_buffer,
"sample_id": id_buffer,
"text": text_buffer,
},
schema,
)
writer.write(batch)
entropies_buffer = []
id_buffer = []
text_buffer = []
patch_time += time.time() - patch_start
step += 1
if step % log_step == 0:
print("Completed steps:", step)
progress.update(task, advance=1)
if len(entropies_buffer) > 0:
# Write last things
batch = pa.record_batch(
{
"entropies": entropies_buffer,
"sample_id": id_buffer,
"text": text_buffer,
},
schema,
)
writer.write(batch)
entropies_buffer = []
id_buffer = []
text_buffer = []
Path(f"{output_file}.complete").touch()
    except BaseException:
Path(output_file).unlink(missing_ok=True)
raise
elapsed = time.time() - start_time
print("steps", step)
print("done in:", elapsed)
if __name__ == "__main__":
typer.run(main)

694
bytelatent/probe.py Normal file
View file

@ -0,0 +1,694 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This file from the xFormers repo is just an example of how to implement
# probing of the activations of a model, without changing anything.
# By default, the linear inputs/outputs/gradients are logged, as well as
# the attention logits+entropy. It is possible to log an additional tensor, eg:
# x = log_stats(x, "name")
#
# Known limitations:
# * Only a subset of the attention biases is supported
# * Torch-compile is disabled automatically when this is enabled
# * Only tested with bf16/f16/f32 datatypes
import contextlib
import functools
import json
import math
import os
import uuid
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
checkpoint_wrapper,
)
from torch.fx.operator_schemas import normalize_function
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torch.utils.module_tracker import ModuleTracker
from xformers.ops import fmha
@torch.library.custom_op("torchprobe::log", mutates_args=(), device_types=None)
def _log(x: torch.Tensor, name: str, uid: str) -> None:
pass
@_log.register_fake
def _log_fake(x: torch.Tensor, name: str, uid: str) -> None:
pass
class _LogStats(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, name: str):
uid = str(uuid.uuid4())
torch.ops.torchprobe.log(x, name, uid)
ctx.name = name
ctx.uid = uid
return x
@staticmethod
def backward(ctx, grad: torch.Tensor):
torch.ops.torchprobe.log(grad, f"{ctx.name}.g", ctx.uid)
return grad, None
_PROBING_ENABLED = False
def log_stats(x: torch.Tensor, name: str) -> torch.Tensor:
if not _PROBING_ENABLED:
return x
return _LogStats.apply(x, name)
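
Outside an active AutoProbeD context the helper is a pure pass-through, so it can stay permanently in model code; a minimal sketch (the tensor name is hypothetical):

x = torch.randn(4, 4, requires_grad=True)
y = log_stats(x, "residual_stream")
assert y is x  # probing disabled -> identity, no overhead
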
QUANTILES = [
0.0000001,
0.000001,
0.00001,
0.0001,
0.001,
0.01,
0.05,
0.1,
0.3,
0.5,
0.7,
0.9,
0.95,
0.99,
0.999,
0.9999,
0.99999,
0.999999,
0.9999999,
]
@functools.cache
def _get_quantiles(device: torch.device, dtype) -> torch.Tensor:
return torch.tensor(QUANTILES, device=device, dtype=dtype)
def _get_stats(x_: torch.Tensor, remove_inf=False) -> Dict[str, Any]:
if x_.dtype not in [torch.float, torch.double, torch.float16, torch.bfloat16]:
return {}
x = x_.flatten()
if remove_inf:
x = x[x.abs() < float("inf")]
if x.dtype is not torch.double:
x = x.float()
xabs = x.abs()
quantiles = _get_quantiles(x.device, x.dtype)
mean = x.mean()
std = x.std()
return {
"shape": tuple(x_.shape),
"mean": mean,
"std": std,
"skew": (((x - mean) / std) ** 3).double().mean(),
"kurtosis": (((x - mean) / std) ** 4).double().mean(),
"abs.mean": xabs.mean(),
"max": x.max(),
"min": x.min(),
# Note: `quantile` takes at most 2**24 elements, see
# https://github.com/pytorch/pytorch/issues/64947
"quantiles": torch.quantile(x[: 2**24], quantiles),
}
def _mask_attn_causal_inplace(logits: torch.Tensor, q_idx, q_len, kv_len) -> None:
assert logits.ndim == 4
logits[:, :, :, q_idx + kv_len - q_len + 1 :] = -math.inf
def _mask_attn_logits(
logits: torch.Tensor,
q_idx: List[int],
*,
causal: bool,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert logits.dtype is torch.float32
# Handle BlockDiagonalMask
if cu_seqlens_q is not None:
assert cu_seqlens_k is not None
# Expect BHMqMkv
assert logits.ndim == 4, logits.shape
qs = cu_seqlens_q.tolist()
ks = cu_seqlens_k.tolist()
q_batchid = []
k_batchid = [-2] * logits.shape[-1]
q_idx_i = 0
for bid, (q0, q1, k0, k1) in enumerate(zip(qs, qs[1:], ks, ks[1:])):
for k in range(k0, k1):
k_batchid[k] = bid
while q_idx_i < len(q_idx) and q_idx[q_idx_i] < q1:
q_batchid.append(bid)
if causal:
_mask_attn_causal_inplace(
logits[:, :, q_idx_i : q_idx_i + 1, k0:k1],
q_idx[q_idx_i] - q0,
q1 - q0,
k1 - k0,
)
q_idx_i += 1
mask_out = (
torch.tensor(q_batchid, device=logits.device)[None, None, :, None]
!= torch.tensor(k_batchid, device=logits.device)[None, None, None, :]
)
logits[mask_out.expand_as(logits)] = -math.inf
assert q_idx_i == len(q_idx)
elif causal:
for q_idx_i in range(len(q_idx)):
_mask_attn_causal_inplace(
logits[:, :, q_idx_i : q_idx_i + 1, :],
q_idx[q_idx_i],
logits.shape[2],
logits.shape[3],
)
return logits
def _attn_queries_subset(num_queries: int) -> List[int]:
return list(range(0, num_queries, max(1, num_queries // 128)))
@torch.no_grad()
def _compute_attn_stats_sdpa(
probe,
path: str,
# supports arguments both cudnn + flash backends
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask=None,
attn_bias=None,
dropout_p=0.0,
is_causal=False,
scale=None,
compute_log_sumexp=True,
return_debug_mask=False,
**kwargs,
):
if scale is None:
scale = 1 / (query.shape[-1] ** 0.5)
# Filter-out not supported cases
if attn_mask is not None or attn_bias is not None or dropout_p != 0.0 or kwargs:
probe.store[f"{path}::attn"] = {
"query.shape": tuple(query.shape),
"key.shape": tuple(key.shape),
"value.shape": tuple(value.shape),
"attn_mask": attn_mask.shape if attn_mask is not None else None,
"dropout_p": dropout_p,
"is_causal": is_causal,
"scale": scale,
"unk_kwargs": list(kwargs.keys()),
}
return
# Take a subset of the queries and compute the logits
query_s = _attn_queries_subset(query.shape[-2])
logits = query[:, :, query_s] @ key.transpose(-1, -2) * scale
logits = _mask_attn_logits(logits.float(), query_s, causal=is_causal)
p = logits.float().softmax(-1)
masked_logsoft = logits.log_softmax(-1).where(
(logits > -math.inf), torch.zeros_like(logits)
)
entropy = -(p * masked_logsoft).sum(-1)
probe.log_tensor(f"{path}::attn_entropy", entropy)
probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
@torch.no_grad()
def _compute_attn_stats_flash(
probe,
path: str,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor],
cu_seqlens_k: Optional[torch.Tensor],
seqused_k: Optional[torch.Tensor],
max_seqlen_q: int,
max_seqlen_k: int,
p: float,
softmax_scale: float,
is_causal: bool,
window_left: int,
window_right: int,
return_softmax: bool,
block_tables: Optional[torch.Tensor],
unpadded_lse: bool = False,
) -> None:
# Filter-out not supported cases
if (
seqused_k is not None
or p != 0.0
or window_left >= 0
or window_right >= 0
or block_tables is not None
):
probe.store[f"{path}::attn"] = {
"query.shape": tuple(query.shape),
"key.shape": tuple(key.shape),
"value.shape": tuple(value.shape),
"op": "flash",
}
return
if cu_seqlens_q is not None:
assert query.ndim == 3, query.shape
query, key, value = query[None], key[None], value[None]
assert query.ndim == 4, query.shape
# Take a subset of the queries and compute the logits
query_s = _attn_queries_subset(query.shape[1])
logits = (
query[:, query_s].transpose(1, 2)
@ key.transpose(1, 2).transpose(-1, -2)
* softmax_scale
)
logits = _mask_attn_logits(
logits.float(),
query_s,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
causal=is_causal,
)
p = logits.float().softmax(-1)
masked_logsoft = logits.log_softmax(-1).where(
(logits > -math.inf), torch.zeros_like(logits)
)
entropy = -(p * masked_logsoft).sum(-1)
probe.log_tensor(f"{path}::attn_entropy", entropy)
probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
def _tensors_to_python(x):
if not isinstance(x, torch.Tensor):
return x
return x.tolist()
# class syntax
class LinearBwType(Enum):
DW = 1
DX = 2
UNKNOWN = 3
class AutoProbeD(TorchDispatchMode):
def __init__(self, module: nn.Module, write_file: Optional[str] = None) -> None:
self.write_file = Path(write_file) if write_file is not None else None
self.write_tensors_tmpdir: Optional[Path] = None
self.compile_disabler = TorchCompileDisabler(module)
self.mod_tracker = ModuleTracker()
self.count_per_path: Dict[str, int] = defaultdict(int)
self.store: Dict[str, Dict[str, Any]] = {}
self.linear_data: Dict[str, Tuple[Any, Any, Any, Any, Any]] = {}
self.uid_to_path: Dict[str, str] = {}
self.metadata: Any = None
self.enabled = False
self.verbose = bool(int(os.environ.get("PROBE_VERBOSE", "0")))
def __enter__(self):
global _PROBING_ENABLED
assert not self.enabled, "Entered probe twice"
self.compile_disabler.__enter__()
self.mod_tracker.__enter__()
super().__enter__()
self.enabled = True
_PROBING_ENABLED = True
# self._setup_tensors_logging()
return self
def __exit__(self, *args) -> None:
global _PROBING_ENABLED
assert self.enabled, "Exiting probe without entering it"
super().__exit__(*args)
self.mod_tracker.__exit__(*args)
self.compile_disabler.__exit__(*args)
self._flush_and_clear()
_PROBING_ENABLED = False
self.enabled = False
def _setup_tensors_logging(self):
if self.write_file is not None:
self.write_file.parent.mkdir(exist_ok=True)
self.write_tensors_tmpdir = (
self.write_file.parent
/ f"{self.write_file.name}-tmp-{str(uuid.uuid4())[:8]}"
)
self.write_tensors_tmpdir.mkdir(exist_ok=True)
def _flush_and_clear(self) -> None:
if self.write_file is not None:
dump_data = tree_map(_tensors_to_python, self.store)
with self.write_file.open("a") as fd:
json.dump(
{
"data": dump_data,
"meta": self.metadata,
"version": 2,
"quantiles": QUANTILES,
},
fd,
)
fd.write("\n")
if self.write_tensors_tmpdir is not None:
assert self.write_file is not None
dump_dir = self.write_tensors_tmpdir.parent / f"{self.write_file.name}-dump"
dump_dir.mkdir(exist_ok=True)
dir_name = ""
if "it" in self.metadata:
dir_name = f"it{int(self.metadata['it']):010}"
if dir_name == "" or (dump_dir / dir_name).exists():
num_files = len(list(dump_dir.glob(f"{dir_name}v*")))
dir_name = f"{dir_name}v{num_files}"
dump_dir = dump_dir / dir_name
assert not dump_dir.exists()
self.write_tensors_tmpdir.rename(dump_dir)
self.write_tensors_tmpdir = None
self.store.clear()
self.count_per_path.clear()
self.uid_to_path.clear()
def _find_bw_path_and_type(
self, path: str, out: torch.Tensor, args
) -> Tuple[str, LinearBwType]:
"""
We are in the BW pass, and process a GEMM.
Let's figure out:
(1) The path for the FW pass (might differ in case of ModuleTracker bug)
(2) The type of BW pass (eg `dw` or `dx`)
"""
def _is_path_correct_dw(path: str) -> bool:
# dW.t = dY.t @ X
in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
return out.shape == (w_shape[1], w_shape[0]) and torch.allclose(
input_sm, args[1][:4, :4]
)
def _is_path_correct_dx(path: str) -> bool:
# dX = dY @ W.t
in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
return out.shape == in_shape and torch.allclose(weight_sm, args[1][:4, :4])
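# The FW pass saved, for each linear path, the (input, weight, output) shapes
# plus small 4x4 slices of the input and transposed weight; matching the BW
# GEMM's output shape and comparing those slices tells dW apart from dX.
# Try the path reported by ModuleTracker first, then fall back to scanning all
# parent module paths recorded during the forward pass.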
if path in self.linear_data:
if _is_path_correct_dw(path):
return path, LinearBwType.DW
if _is_path_correct_dx(path):
return path, LinearBwType.DX
for candidate_path in self.mod_tracker.parents:
if candidate_path not in self.linear_data:
continue
if _is_path_correct_dw(candidate_path):
return candidate_path, LinearBwType.DW
if _is_path_correct_dx(candidate_path):
return candidate_path, LinearBwType.DX
return path, LinearBwType.UNKNOWN
def log_tensor(self, name: str, x: torch.Tensor, **kwargs) -> None:
self.store[name] = _get_stats(x, **kwargs)
if self.write_tensors_tmpdir is not None:
name_safe = name.replace("::", "__").replace("/", "")
torch.save(x, self.write_tensors_tmpdir / f"{name_safe}.pkl")
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}
path = None
# Find longest path
for p in self.mod_tracker.parents:
if p == "Global":
continue
if path is None or len(p) > len(path):
path = p
if path is None:
path = "Global"
path = path.replace("._checkpoint_wrapped_module", "")
out = func(*args, **kwargs)
# Handle linear layers
if func._overloadpacket in [torch.ops.aten.addmm, torch.ops.aten.mm]:
weight: torch.Tensor
input: torch.Tensor
if not self.mod_tracker.is_bw:
# (technically, weight is transposed)
if func._overloadpacket == torch.ops.aten.addmm:
_bias, input, weight = args[:3]
else:
assert func._overloadpacket == torch.ops.aten.mm
input, weight = args[:2]
self.log_tensor(f"{path}::in", input)
self.log_tensor(f"{path}::w", weight)
self.log_tensor(f"{path}::out", out)
self.linear_data[path] = (
input.shape,
weight.shape,
out.shape,
input[:4, :4].clone(),
weight[:4, :4].T.clone(),
)
elif func._overloadpacket == torch.ops.aten.mm:
# XXX: Try to find the actual path for the linear layer
# This sometimes gets messed up by Francisco's FSDP
new_path, bwtype = self._find_bw_path_and_type(path, out, args)
if new_path != path:
if self.verbose:
print(f"E: Fixing path `{path}` -> `{new_path}`")
path = new_path
if bwtype == LinearBwType.DW:
# dW.t = dY.t @ X
self.log_tensor(f"{path}::w.g", out)
elif bwtype == LinearBwType.DX:
# dX = dY @ W.t
self.log_tensor(f"{path}::in.g", out)
self.log_tensor(f"{path}::out.g", args[0])
elif func._overloadpacket in [
torch.ops.aten._scaled_dot_product_flash_attention,
torch.ops.aten._scaled_dot_product_cudnn_attention,
]:
_, kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
_compute_attn_stats_sdpa(self, path, **kwargs)
elif func._overloadpacket == fmha.flash.FwOp.OPERATOR:
_, kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
_compute_attn_stats_flash(self, path, **kwargs)
elif func._overloadpacket == torch.ops.torchprobe.log:
uid = args[2]
path = self.uid_to_path.setdefault(uid, path)
self.log_tensor(f"{path}::{args[1]}", args[0])
if self.verbose:
print(f"{'[BW]' if self.mod_tracker.is_bw else '[FW]'} `{path}`: {func}")
return out
def _find_all_submodules_compiled(out: List[nn.Module], module: nn.Module) -> None:
if module._compiled_call_impl is not None:
out.append(module)
for c in module.children():
_find_all_submodules_compiled(out, module=c)
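# The probe relies on __torch_dispatch__ seeing every aten op, so while it is
# active we disable torch.compile and temporarily strip the cached
# _compiled_call_impl from every submodule (restored on __exit__).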
class TorchCompileDisabler:
def __init__(self, module: nn.Module) -> None:
self.module = module
self.submodules_compiled: List[nn.Module] = []
self.compiled_call_impl: List[Any] = []
self.disable_compile = torch.compiler.disable()
torch._dynamo.config.raise_on_ctx_manager_usage = False # type: ignore
def __enter__(self) -> None:
# Remove all `_compiled_call_impl` attributes to effectively
# "undo" compilation
self.submodules_compiled.clear()
_find_all_submodules_compiled(self.submodules_compiled, self.module)
self.compiled_call_impl = [
m._compiled_call_impl for m in self.submodules_compiled
]
for m in self.submodules_compiled:
m._compiled_call_impl = None
self.disable_compile.__enter__() # type: ignore
def __exit__(self, *args) -> None:
self.disable_compile.__exit__(*args) # type: ignore
for m, c_impl in zip(self.submodules_compiled, self.compiled_call_impl):
m._compiled_call_impl = c_impl
self.compiled_call_impl = []
Probe = AutoProbeD
# EXAMPLE USAGE
d = 512
seqlen = 4
bs = 2
class Attention1(nn.Module):
def forward(self, x):
attn_bias = fmha.attn_bias.LowerTriangularFromBottomRightMask()
return fmha.memory_efficient_attention(x, x, x, attn_bias=attn_bias).reshape(
[x.shape[0], seqlen, -1]
)
class Attention2(nn.Module):
def forward(self, x):
attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
[seqlen] * bs
).make_causal()
xr = x.reshape([1, 2 * seqlen, x.shape[2], x.shape[3]])
return fmha.memory_efficient_attention(xr, xr, xr, attn_bias=attn_bias).reshape(
[x.shape[0], seqlen, -1]
)
class AttentionSDPA(nn.Module):
def __init__(self):
super().__init__()
self.wo = nn.Linear(d, d)
def forward(self, x):
x = x.transpose(1, 2)
return self.wo(
F.scaled_dot_product_attention(x, x, x)
.transpose(1, 2)
.reshape([x.shape[0], seqlen, -1])
)
class AttentionSDPAFlash(AttentionSDPA):
def forward(self, x):
x = x.transpose(1, 2)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
return self.wo(
F.scaled_dot_product_attention(x, x, x)
.transpose(1, 2)
.reshape([x.shape[0], seqlen, -1])
)
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.head = nn.Linear(d, 16)
self.trunk = nn.Sequential(
nn.Linear(d, d),
nn.Linear(d, d),
)
self.q_proj = nn.Linear(d, d, bias=False)
self.trunk.compile()
self.attn1 = Attention1()
self.attn2 = Attention2()
self.attnSDPA = AttentionSDPA()
self.attnSDPAflash = AttentionSDPAFlash()
def forward(self, x):
B, nHeads, D = x.shape[0], d // 64, 64
x = self.q_proj(x).reshape([B, seqlen, nHeads, D])
x = self.attn1(x) + self.attn2(x) + self.attnSDPA(x) + self.attnSDPAflash(x)
x = log_stats(x, "attns_out")
return self.head(self.trunk(x))
def test_masking() -> None:
q_seqlen = [1, 1, 14, 12]
kv_seqlen = [2, 2, 14, 18]
attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen, kv_seqlen
).make_causal_from_bottomright()
logits = torch.randn(
[1, 1, sum(q_seqlen), sum(kv_seqlen)], dtype=torch.float32, device="cuda"
)
bias = attn_bias.materialize(logits.shape, dtype=logits.dtype, device=logits.device)
logits_masked = logits.clone()
_mask_attn_logits(
logits_masked,
list(range(logits.shape[2])),
causal=True,
cu_seqlens_q=attn_bias.q_seqinfo.seqstart,
cu_seqlens_k=attn_bias.k_seqinfo.seqstart,
)
assert (logits + bias == logits_masked).all().item()
def test_toy_model() -> None:
# Build a toy model, run a few training steps, and probe one of them
kw = dict(device="cuda", dtype=torch.float16)
x = torch.randn([bs, seqlen, d], **kw)
m = Model()
m.head = checkpoint_wrapper(
m.head, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
)
m.to(**kw)
m.compile()
optim = torch.optim.SGD(m.parameters(), lr=0.0)
probe = AutoProbeD(m, "./probe.json")
for i in range(4):
with contextlib.ExitStack() as stack:
print(f"########### STEP {i}")
if i % 4 == 1:
stack.enter_context(probe)
probe.metadata = {"it": i}
y = m(x)
g = torch.randn_like(y)
y.backward(g)
if i % 4 == 1:
assert probe.enabled
# Make sure we registered all linears
print(list(probe.store.keys()))
for key in [
"Model::attns_out",
"Model::attns_out.g",
"Model.attn1::attn_logits",
"Model.attn2::attn_logits",
"Model.attnSDPA::attn_logits",
"Model.attnSDPAflash::attn_logits",
"Model.head::w",
"Model.head::w.g",
"Model.head::in",
"Model.head::in.g",
"Model.head::out",
"Model.head::out.g",
"Model.trunk.0::in",
"Model.trunk.1::in",
]:
assert key in probe.store, f"Missing key: '{key}'"
# .. and that the values are correct
for key, tensor in [
("Model.head::w", m.head.weight),
("Model.head::w.g", m.head.weight.grad),
("Model.q_proj::in", x),
("Model.q_proj::w.g", m.q_proj.weight.grad),
("Model.head::out", y),
("Model.head::out.g", g),
]:
assert key in probe.store, f"Missing key: '{key}'"
assert torch.allclose(
probe.store[key]["abs.mean"], tensor.float().abs().mean()
), f"'{key}' mismatches"
# Check we don't have `nans`
for key, value in probe.store.items():
if "abs.mean" in value:
assert math.isfinite(
value["abs.mean"].item()
), f"Inf/Nan for {key}"
optim.step()
optim.zero_grad()

133
bytelatent/profiling.py Normal file
View file

@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import contextlib
import logging
import os
from pathlib import Path
import torch.distributed
import wandb
import xformers.profiler
from pydantic import BaseModel
from torch.profiler.profiler import profile
from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler
from bytelatent.distributed import get_is_master
class ProfilerArgs(BaseModel):
run: bool = False
trace_folder: str = "profiling"
mem_warmup: int = 100
mem_steps: int = 2
profile_warmup: int = 102
profile_steps: int = 2
logger = logging.getLogger()
def perfetto_to_html(json_file, html_file):
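# Embed a (possibly gzipped) Chrome/Perfetto trace JSON into viztracer's
# standalone trace-viewer HTML template so it can be opened in a browser.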
import gzip
import string
import viztracer
root = os.path.dirname(viztracer.__file__)
sub = {}
json_file = gzip.open(json_file) if ".gz" in str(json_file) else open(json_file)
with open(
os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8"
) as f:
tmpl = f.read()
with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f:
sub["trace_viewer_full"] = f.read()
with json_file as j:
content = j.read()
if isinstance(content, bytes):
content = content.decode("utf-8")
sub["json_data"] = content.replace("</script>", "<\\/script>") # type: ignore
with open(html_file, "w+", encoding="utf-8") as output_file:
output_file.write(string.Template(tmpl).substitute(sub))
class PyTorchProfilerWandb(PyTorchProfiler):
def __init__(self, main_profiler) -> None:
self.main_profiler = main_profiler
self.num_steps = 0
self.pytorch_profiler = torch.profiler.profile(
on_trace_ready=self._on_trace,
profile_memory=True,
record_shapes=True,
# with_stack=True produces huge profile traces and crashes on a
# non-ASCII character somewhere in PyTorch, so keep it off
with_stack=False,
with_flops=True,
activities=self.ACTIVITIES,
)
def _analyze_trace(self, prof: profile):
logger.info("Begin analyze trace")
super()._analyze_trace(prof)
logger.info("End analyze trace")
def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
super()._on_trace(prof)
if get_is_master() and wandb.run is not None:
filename = list(
Path(self.main_profiler.output_dir).glob(
"profile_CPU_CUDA*/*.pt.trace.json*"
)
)[0]
html_path = str(filename).replace(".json", ".html")
perfetto_to_html(filename, html_path)
wandb.log({"profile_trace": wandb.Html(html_path)})
class MemSnapshotsProfilerWandb(MemSnapshotsProfiler):
def __exit__(self, exc_type, exc_val, exc_tb):
super().__exit__(exc_type, exc_val, exc_tb)
if get_is_master() and wandb.run is not None:
filename = list(
Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html")
)[0]
wandb.log({"memory_trace": wandb.Html(open(filename), inject=False)})
@contextlib.contextmanager
def maybe_run_profiler(dump_dir, module, config: ProfilerArgs):
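# Schedule (in training steps): take a memory snapshot between mem_warmup and
# mem_warmup + mem_steps, then a PyTorch profiler trace between profile_warmup
# and profile_warmup + profile_steps; both are logged to W&B on the master rank
# when a wandb run is active.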
# get user defined profiler settings
if config.run:
trace_dir = os.path.join(dump_dir, config.trace_folder)
logger.info(f"Profiling active. Traces will be saved at {trace_dir}")
if get_is_master() and not os.path.exists(trace_dir):
os.makedirs(trace_dir)
if torch.distributed.is_initialized():
torch.distributed.barrier()
with xformers.profiler.profile(
output_dir=trace_dir,
module=module,
schedule=[
(
MemSnapshotsProfilerWandb,
config.mem_warmup,
config.mem_warmup + config.mem_steps,
),
(
PyTorchProfilerWandb,
config.profile_warmup,
config.profile_warmup + config.profile_steps,
),
],
) as profiler:
yield profiler
else:
yield None

237
bytelatent/stool.py Normal file
View file

@@ -0,0 +1,237 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import os
import shutil
import subprocess
from dataclasses import dataclass
from typing import Any, Dict
from omegaconf import OmegaConf
@dataclass
class StoolArgs:
config: Any = None
launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
script: str = "apps.main.train" # The script to run.
copy_code: bool = True # Whether to copy code to the dump dir
dirs_exists_ok: bool = (
False # Whether to copy new code and config and run even if the dump dir already exists
)
override: bool = False # Whether to delete the dump dir and restart
nodes: int = -1 # The number of nodes to run the job on.
ngpu: int = 8 # The number of GPUs required per node.
ncpu: int = 16 # The number of CPUs allocated per GPU.
mem: str = "" # The amount of memory to allocate.
anaconda: str = "default" # The path to the anaconda environment.
constraint: str = "" # The constraint on the nodes.
exclude: str = "" # The nodes to exclude.
time: int = -1 # The time limit of the job (in minutes).
account: str = ""
qos: str = ""
partition: str = "learn"
stdout: bool = False
SBATCH_COMMAND = """#!/bin/bash
{exclude}
{qos}
{account}
{constraint}
#SBATCH --job-name={name}
#SBATCH --nodes={nodes}
#SBATCH --gres=gpu:{ngpus}
#SBATCH --cpus-per-gpu={ncpu}
#SBATCH --time={time}
#SBATCH --partition={partition}
#SBATCH --mem={mem}
#SBATCH --output={dump_dir}/logs/%j/%j.stdout
#SBATCH --error={dump_dir}/logs/%j/%j.stderr
#SBATCH --open-mode=append
#SBATCH --signal=USR2@120
#SBATCH --distribution=block
# Mimic the effect of "conda init", which doesn't work for scripts
eval "$({conda_exe} shell.bash hook)"
source activate {conda_env_path}
{go_to_code_dir}
export OMP_NUM_THREADS=1
export LAUNCH_WITH="SBATCH"
export DUMP_DIR={dump_dir}
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
"""
def copy_dir(input_dir: str, output_dir: str) -> None:
print(f"Copying : {input_dir}\n" f"to : {output_dir} ...")
assert os.path.isdir(input_dir), f"{input_dir} is not a directory"
assert os.path.isdir(output_dir), f"{output_dir} is not a directory"
rsync_cmd = (
f"rsync -arm --copy-links "
f"--include '**/' "
f"--include '*.py' "
f"--exclude='*' "
f"{input_dir}/ {output_dir}"
)
print(f"Copying command: {rsync_cmd}")
subprocess.call([rsync_cmd], shell=True)
print("Copy done.")
def retrieve_max_time_per_partition() -> Dict[str, int]:
# retrieve partition max times (a bit slow)
sinfo = json.loads(subprocess.check_output("sinfo --json", shell=True))["sinfo"]
max_times: Dict[str, int] = {}
for info in sinfo:
if info["partition"]["maximums"]["time"]["infinite"]:
max_times[info["partition"]["name"]] = 14 * 24 * 60 # 14 days
else:
max_times[info["partition"]["name"]] = info["partition"]["maximums"][
"time"
][
"number"
] # in minutes
return max_times
def validate_args(args) -> None:
# Set maximum time limit if not specified
if args.time == -1:
max_times = retrieve_max_time_per_partition()
args.time = max_times.get(
args.partition, 3 * 24 * 60
) # Default to 3 days if not found
print(
f"No time limit specified, using max time for partitions: {args.time} minutes"
)
if args.constraint:
args.constraint = f"#SBATCH --constraint={args.constraint}"
if args.account:
args.account = f"#SBATCH --account={args.account}"
if args.qos:
args.qos = f"#SBATCH --qos={args.qos}"
if getattr(args, "exclude", ""):
args.exclude = f"#SBATCH --exclude={args.exclude}"
if hasattr(args, "anaconda") and args.anaconda:
if args.anaconda == "default":
args.anaconda = (
subprocess.check_output("which python", shell=True)
.decode("ascii")
.strip()
)
else:
args.anaconda = f"{args.anaconda}/bin/python"
assert os.path.isfile(args.anaconda)
args.mem = args.mem or "0"
assert args.partition
assert args.ngpu > 0
assert args.ncpu > 0
assert args.nodes > 0
assert args.time > 0
assert args.partition
def launch_job(args: StoolArgs):
# Set up default args and validate them depending on the cluster or partition requested
validate_args(args)
dump_dir = args.config["dump_dir"]
job_name = args.config["name"]
print("Creating directories...")
os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
if args.override:
confirm = input(
f"Are you sure you want to delete the directory '{dump_dir}'? This action cannot be undone. (yes/no): "
)
if confirm.lower() == "yes":
shutil.rmtree(dump_dir)
print(f"Directory '{dump_dir}' has been deleted.")
else:
print("Operation cancelled.")
return
if args.copy_code:
os.makedirs(f"{dump_dir}/code", exist_ok=args.dirs_exists_ok)
print("Copying code ...")
copy_dir(os.getcwd(), f"{dump_dir}/code")
print("Saving config file ...")
with open(f"{dump_dir}/base_config.yaml", "w") as cfg:
cfg.write(OmegaConf.to_yaml(args.config))
conda_exe = os.environ.get("CONDA_EXE", "conda")
conda_env_path = os.path.dirname(os.path.dirname(args.anaconda))
log_output = (
"-o $DUMP_DIR/logs/%j/%j_%t.out -e $DUMP_DIR/logs/%j/%j_%t.err"
if not args.stdout
else ""
)
sbatch = SBATCH_COMMAND.format(
name=job_name,
script=args.script,
dump_dir=dump_dir,
nodes=args.nodes,
tasks=args.nodes * args.ngpu,
nodes_per_run=args.nodes,
ngpus=args.ngpu,
ncpu=args.ncpu,
mem=args.mem,
qos=args.qos,
account=args.account,
constraint=args.constraint,
exclude=args.exclude,
time=args.time,
partition=args.partition,
conda_exe=conda_exe,
conda_env_path=conda_env_path,
log_output=log_output,
go_to_code_dir=f"cd {dump_dir}/code/" if args.copy_code else "",
)
print("Writing sbatch command ...")
with open(f"{dump_dir}/submit.slurm", "w") as f:
f.write(sbatch)
print("Submitting job ...")
os.system(f"{args.launcher} {dump_dir}/submit.slurm")
print("Done.")
if __name__ == "__main__":
"""
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
This accepts arguments as a dot list
So if the dataclass looks like
@dataclass
class DummyArgs:
name: str
model: LMTransformerArgs
@dataclass
class LMTransformerArgs:
dim: int
Then you can pass model.dim=32 to change values in LMTransformerArgs
or just name=tictac for top level attributes.
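Example invocation as a dot list (hypothetical config path):
python -m bytelatent.stool config=apps/main/configs/debug.yaml nodes=2 ngpu=8 partition=learn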
"""
raise NotImplementedError("Update this to blt code")
args = OmegaConf.from_cli()
args.config = OmegaConf.load(args.config)
args = dataclass_from_dict(StoolArgs, args)
launch_job(args)

471
bytelatent/test_blt.py Normal file
View file

@@ -0,0 +1,471 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
from dataclasses import replace
import numpy as np
import pytest
import torch
from bytelatent.constants import BLT_DATA
from bytelatent.data.data_types import Batch
from bytelatent.data.ngram_processor import NgramProcessor
from bytelatent.model.blt import (
ByteLatentTransformer,
ByteLatentTransformerArgs,
EmbeddingType,
compute_hash_embeddings,
create_global_transformer,
create_local_decoder,
create_local_encoder,
cross_attn_mask,
decoder_patch_ids_from_lengths,
get_blt_input,
init_embeddings,
patch_ids_from_lengths,
)
from bytelatent.model.transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask
from bytelatent.optim import OptimArgs, build_optimizer
from bytelatent.train import compute_loss
def batch_to_tensors_and_gpu(batch):
x = torch.from_numpy(batch.x)
y = torch.from_numpy(batch.y)
mask = None if batch.mask is None else torch.from_numpy(batch.mask)
patch_lengths = (
None if batch.patch_lengths is None else torch.from_numpy(batch.patch_lengths)
)
ngram_ids = None if batch.ngram_ids is None else torch.from_numpy(batch.ngram_ids)
if torch.cuda.is_available():
x = x.cuda()
y = y.cuda()
if mask is not None:
mask = mask.cuda()
if patch_lengths is not None:
patch_lengths = patch_lengths.cuda()
if ngram_ids is not None:
ngram_ids = ngram_ids.cuda()
return x, y, mask, patch_lengths, ngram_ids
def fake_batch():
batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"))
del batch_dict["x2"]
del batch_dict["y2"]
del batch_dict["src_names"]
return Batch(**batch_dict)
def create_args(cross_attention=False):
transformer_args = ByteLatentTransformerArgs(
# Base args provided
n_heads=8,
dim=512,
vocab_size=260,
# Additional args from command line
dim_token=256,
patch_size=6,
tokenization_mode="bytes",
patching_mode="space",
tie_local_encoder_decoder_logits=False,
data_loader_patching=True,
max_encoder_seq_length=12288,
pad_to_max_length=True,
encoder_lm_loss=False,
patching_threshold=3.1439168453216553,
encoder_hash_byte_group_size=[4],
encoder_hash_byte_group_vocab=50002,
encoder_hash_byte_group_nb_functions=3,
cross_attn_encoder=cross_attention, # True,
cross_attn_decoder=cross_attention, # True,
cross_attn_window_encoder=512,
cross_attn_window_decoder=512,
dim_local_encoder=256,
dim_local_decoder=256,
cross_attn_k=8,
cross_attn_nheads=4,
cross_attn_all_layers_decoder=True,
cross_attn_all_layers_encoder=True,
cross_attn_use_flex_attention=True,
cross_attn_init_by_pooling=True,
log_patch_lengths=True,
non_linearity="swiglu",
use_rope=True,
recompute_fc1_out=False,
recompute_fc3_out=False,
recompute_attn=False,
custom_bwd=False,
layer_ckpt="none",
efficient_attn="sdpa",
patch_only_encoder=False,
patch_only_decoder=False,
use_local_encoder_transformer=True,
init_use_gaussian=True,
init_use_depth="current",
attn_bias_type="block_causal",
alpha_depth="disabled",
max_length=256,
local_attention_window_len=512,
max_seqlen=12288,
downsampling_by_pooling="max",
)
return transformer_args
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
class TestByteLatentTransformer:
def test_local_encoder(self):
args = create_args()
device = torch.device("cuda")
local_encoder = create_local_encoder(args).to(device)
batch = fake_batch()
tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
local_encoder_tokens, _, _ = get_blt_input(
tokens=tokens,
enforce_patch_size_multiple=False,
nb_boe=0,
patch_size=local_encoder.patch_size,
boe_id=local_encoder.boe_id,
)
patch_ids = patch_ids_from_lengths(
patch_lengths, local_encoder_tokens.shape[-1]
)
encoder_hash_tok_embedding = init_embeddings(
args,
EmbeddingType.HASH_TOK,
local_encoder_dim=local_encoder.dim,
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
).to(device)
local_encoder_embeds = compute_hash_embeddings(
local_encoder_tokens=local_encoder_tokens,
local_encoder=local_encoder,
encoder_hash_tok_embedding=encoder_hash_tok_embedding,
encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
)
reference_path = os.path.join(BLT_DATA, "local_encoder_tokens.pt")
reference_tokens = torch.load(reference_path).to(device)
torch.testing.assert_close(
local_encoder_tokens,
reference_tokens,
msg="Generated tokens don't match reference tokens",
)
(h_encoder, h_cross), cache_encoder = local_encoder(
tokens=local_encoder_tokens,
embeds=local_encoder_embeds,
patch_embeds=None,
cross_mask=None,
num_patches=patch_lengths.shape[1],
patch_ids=patch_ids,
)
assert h_encoder is not None
assert h_cross is None
assert cache_encoder is None
expected_shape = (
local_encoder_tokens.shape[0],
local_encoder_tokens.shape[1],
local_encoder.dim,
)
assert h_encoder.shape == expected_shape
def test_local_encoder_cross_attention(self):
args = create_args(cross_attention=True)
device = torch.device("cuda")
local_encoder = create_local_encoder(args).to(device)
batch = fake_batch()
tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
local_encoder_tokens, _, _ = get_blt_input(
tokens=tokens,
enforce_patch_size_multiple=False,
nb_boe=0,
patch_size=local_encoder.patch_size,
boe_id=local_encoder.boe_id,
)
patch_ids = patch_ids_from_lengths(
patch_lengths, local_encoder_tokens.shape[-1]
)
encoder_hash_tok_embedding = init_embeddings(
args,
EmbeddingType.HASH_TOK,
local_encoder_dim=local_encoder.dim,
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
).to(device)
cross_attn_mask_enc = cross_attn_mask(
patch_ids,
patch_lengths,
local_encoder_tokens.shape[-1],
patches_as_queries=True,
cross_attn_k=args.cross_attn_k,
window=args.cross_attn_window_encoder,
block_mask=True,
)
local_encoder_embeds = compute_hash_embeddings(
local_encoder_tokens=local_encoder_tokens,
local_encoder=local_encoder,
encoder_hash_tok_embedding=encoder_hash_tok_embedding,
encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
)
(h_encoder, h_cross), cache_encoder = local_encoder(
tokens=local_encoder_tokens,
embeds=local_encoder_embeds,
patch_embeds=None,
cross_mask=cross_attn_mask_enc,
num_patches=patch_lengths.shape[1],
patch_ids=patch_ids,
)
assert h_encoder is not None
assert h_cross is not None
assert cache_encoder is None
expected_shape = (
local_encoder_tokens.shape[0],
local_encoder_tokens.shape[1],
local_encoder.dim,
)
assert h_encoder.shape == expected_shape
assert h_cross.shape == (2, 2048, local_encoder.dim)
def test_local_decoder_cross_attention(self):
args = create_args(cross_attention=True)
device = torch.device("cuda")
local_decoder = create_local_decoder(args).to(device)
test_files = {
"dec_embeds": "dec_embeds.pt",
"decoder_tokens": "local_decoder_tokens.pt",
"patch_embeds": "decoder_patch_cross_embeds.pt",
}
batch = fake_batch()
_, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
tensors = {
name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
for name, filename in test_files.items()
}
decoder_patch_ids = decoder_patch_ids_from_lengths(
patch_lengths, 0, tensors["decoder_tokens"].shape[-1]
)
cross_attn_mask_dec = cross_attn_mask(
decoder_patch_ids,
patch_lengths,
tensors["decoder_tokens"].shape[-1],
patches_as_queries=False,
cross_attn_k=args.cross_attn_k,
window=args.cross_attn_window_decoder,
block_mask=True,
)
output, _ = local_decoder(
embeds=tensors["dec_embeds"],
patch_embeds=tensors["patch_embeds"],
tokens=tensors["decoder_tokens"],
cross_mask=cross_attn_mask_dec,
cache=None,
)
assert output is not None
assert output.shape == (2, tensors["decoder_tokens"].shape[1], args.vocab_size)
def test_local_decoder(self):
args = create_args()
device = torch.device("cuda")
local_decoder = create_local_decoder(args).to(device)
test_files = {
"dec_embeds": "dec_embeds.pt",
"decoder_tokens": "local_decoder_tokens.pt",
"patch_embeds": "decoder_patch_embeds.pt",
}
tensors = {
name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
for name, filename in test_files.items()
}
output, cache_decoder = local_decoder(
embeds=tensors["dec_embeds"],
patch_embeds=tensors["patch_embeds"],
tokens=tensors["decoder_tokens"],
cross_mask=None,
cache=None,
)
assert output is not None
expected_shape = (
tensors["decoder_tokens"].shape[0],
tensors["decoder_tokens"].shape[1],
args.vocab_size,
)
assert output.shape == expected_shape
assert cache_decoder is None
def test_global_transformer(self):
args = create_args()
device = torch.device("cuda")
global_transformer = create_global_transformer(args).to(device)
test_files = {
"global_embeds": "global_embeds.pt",
"global_tokens": "global_tokens.pt",
}
tensors = {
name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
for name, filename in test_files.items()
}
h, cache = global_transformer(
embeds=tensors["global_embeds"], tokens=tensors["global_tokens"]
)
assert h is not None
assert h.shape == (2, 256, 512)
assert cache is None
def test_blt_transformer_init(self):
args = create_args()
model = ByteLatentTransformer(args)
assert model is not None
@pytest.mark.parametrize("attn_type", ["fmha", "sdpa"])
def test_blt_transformer_forward(self, attn_type):
args = create_args()
args = args.model_copy(update=dict(efficient_attn=attn_type))
model = ByteLatentTransformer(args)
model = model.cuda()
batch = fake_batch()
x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
output = model(
tokens=x,
patch_lengths=patch_lengths,
ngram_ids=ngram_ids,
)
assert output is not None
expected_shape = (
x.shape[0],
x.shape[1],
args.vocab_size,
)
assert output.shape == expected_shape
def test_blt_transformer_cross_attn_forward(self):
args = create_args(cross_attention=True)
model = ByteLatentTransformer(args)
model = model.cuda()
batch = fake_batch()
x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
output = model(
tokens=x,
patch_lengths=patch_lengths,
ngram_ids=ngram_ids,
)
assert output is not None
expected_shape = (
x.shape[0],
x.shape[1],
args.vocab_size,
)
assert output.shape == expected_shape
def test_cross_attention_rand(self):
x = torch.randn(2, 256, 512, device="cuda")
kv = torch.randn(2, 256, 512, device="cuda")
cross_attention = CrossAttention(
dim=512,
head_dim=64,
n_heads=8,
n_kv_heads=4,
norm_eps=1e-6,
).to("cuda")
mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None)
output = cross_attention(x, kv, mask)
assert output is not None
assert output.shape == (2, 256, 512)
def test_ngram_embeddings(self):
ngram_to_size = {
2: 38396,
3: 50000,
4: 50000,
5: 50000,
6: 50000,
7: 50000,
8: 50000,
}
batch = fake_batch()
ngram_processor = NgramProcessor(BLT_DATA, ngram_to_size)
ngram_ids = ngram_processor.encode_token_ngrams(batch.x)
ngram_ids = np.stack(ngram_ids, axis=0)
batch = replace(batch, ngram_ids=ngram_ids)
args = create_args(cross_attention=True)
args = args.model_copy(
update=dict(
encoder_ngram_to_size_str="2:38396,3:50000,4:50000,5:50000,6:50000,7:50000,8:50000",
encoder_enable_byte_ngrams=True,
ngram_vocab_sizes=ngram_processor.ngram_vocab_sizes,
)
)
model = ByteLatentTransformer(args)
model = model.cuda()
x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
output = model(
tokens=x,
patch_lengths=patch_lengths,
ngram_ids=ngram_ids,
)
assert output is not None
expected_shape = (
x.shape[0],
x.shape[1],
args.vocab_size,
)
assert output.shape == expected_shape
def test_loss_backward(self):
args = create_args()
args = args.model_copy(update=dict(efficient_attn="sdpa"))
batch = fake_batch()
model = ByteLatentTransformer(args)
steps = 10
optimizer, scheduler = build_optimizer(model, OptimArgs(lr=4e-04), steps)
model = model.cuda()
x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
initial_loss = None
final_loss = None
for step in range(steps):
output = model(
tokens=x,
patch_lengths=patch_lengths,
ngram_ids=ngram_ids,
)
loss, _ = compute_loss(output, y, mask, 1.0)
if step == 0:
initial_loss = loss.item()
if step == steps - 1:
final_loss = loss.item()
prev_loss = loss.item()
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
assert (
final_loss < initial_loss
), f"Training did not reduce loss: initial {initial_loss}, final {final_loss}"

View file

@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
import torch
from bytelatent.constants import BLT_DATA
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum, entropy
from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
ENTROPY_MODEL = "transformer_100m"
ARROW_TEST_DATA = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
def test_entropy_model():
initial_state = ArrowFileIteratorState(
file_path=None,
num_workers=1,
worker_id=0,
preprocess_dir=None,
entropy_model_name=ENTROPY_MODEL,
dataset_files=[ARROW_TEST_DATA],
row_num=0,
arrow_batch_size=100,
)
arrow_file = initial_state.build()
tokenizer_args = TokenizerArgs(
name="blt",
init_kwargs={
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
},
)
entropy_model = load_entropy_model(
BLT_DATA / "checkpoint_0100000_consolidated",
os.path.join(
BLT_DATA,
"entropy_model.pth",
),
)
preprocess_iter = PreprocessIterator(
arrow_file,
tokenizer_args=tokenizer_args,
patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.entropy),
add_patches=False,
)
for example in preprocess_iter.create_iter():
tokens = torch.tensor(example.tokens).unsqueeze(0)
expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
preds = entropy_model(tokens)
pred_entropies = entropy(preds)
assert pred_entropies.shape == expected_entropies.shape
assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5)
break

View file

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

View file

@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import abc
class Tokenizer(abc.ABC):
@abc.abstractmethod
def encode(self, text: str, add_bos: bool, add_eos: bool):
pass
@abc.abstractmethod
def decode(self, tokens: list[int]):
pass
@abc.abstractmethod
def get_token_offsets(
self, text: str, tokens: list[int] | None = None
) -> tuple[list[str], list[int]]:
"""Return the offsets of the tokens in the original text. Only used for evaluation."""
pass

View file

@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import re
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.tokenizers.constants import (
BOE_ID,
BOS_ID,
BPE_ID,
BYTE_UNITS,
EOS_ID,
OFFSET,
PAD_ID,
)
from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
def convert_to_bytes(s):
# Check whether the BPE piece is a byte token of the form <0x00>; if so, decode the hex byte
if re.match(r"<0x[0-9a-fA-F]+>", s):
return bytes.fromhex(s[3:-1])
else:
return bytes(s, "utf-8", errors="ignore")
def text2bytes_bpe_delims(
text: str,
*,
bpe_tokenizer,
bpe_id: int,
offsetting_special_char: int,
add_bos: bool,
add_eos: bool,
):
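# Tokenize `text` with the BPE tokenizer, then re-emit each BPE piece as its
# raw bytes prefixed by the BPE-delimiter id, so a byte-level model still sees
# where the BPE token boundaries fall.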
cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos)
# merge the leading space tokens
leading_space_tokens = []
other_bpe_tokens = []
leading = True
for token in cur_bpe:
bpe_str = bpe_tokenizer.sp_model.id_to_piece(token)
if leading and all(c == "▁" for c in bpe_str):
leading_space_tokens.append(bpe_str)
else:
leading = False
other_bpe_tokens.append(bpe_str)
cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens
# Remove the '▁' characters
bpe_strs = []
for i, bpe_str in enumerate(cur_bpe_strs):
if (
len(bpe_strs) <= 1
and all([c == " " for s in bpe_strs for c in s])
and not all(c == "▁" for c in bpe_str)
):
# Remove leading space for first non space token.
bpe_str = bpe_str.replace("▁", "")
elif i == 0 and all(c == "▁" for c in bpe_str):
bpe_str = " " * (len(text) - len(text.lstrip(" ")))
else:
bpe_str = bpe_str.replace("▁", " ")
if len(bpe_str) > 0:
bpe_strs.append(bpe_str)
ex_seq = []
# Convert bpe tokens to bytes
for s in bpe_strs:
byte_chunk = convert_to_bytes(s)
proc_chunk = [int(unit) for unit in byte_chunk]
ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk)
return ex_seq
class BltTokenizer(Tokenizer):
def __init__(
self,
*,
vocab_size_unit_1: int = BYTE_UNITS,
bpe_delim: bool = False,
bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
add_bos: bool = True,
add_eos: bool = True,
):
self.add_bos = add_bos
self.add_eos = add_eos
self.vocab_size_unit_1 = vocab_size_unit_1
self.boe_id = BOE_ID
self.bos_id = BOS_ID
self.eos_id = EOS_ID
self.pad_id = PAD_ID
self.bpe_id = BPE_ID
self.bpe_tokenizer_path = bpe_tokenizer_path
if bpe_delim:
self.bpe_tokenizer = SentencePieceTokenizer(
model_path=self.bpe_tokenizer_path
)
else:
self.bpe_tokenizer = None
self.bpe_delim = bpe_delim
self.offsetting_special_char = OFFSET
self.vocab_size_unit_1 = vocab_size_unit_1
self.n_words = vocab_size_unit_1 + self.offsetting_special_char
def encode(
self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
):
if add_bos is None:
add_bos = self.add_bos
if add_eos is None:
add_eos = self.add_eos
if self.bpe_delim:
tokens = text2bytes_bpe_delims(
text,
bpe_tokenizer=self.bpe_tokenizer,
bpe_id=self.bpe_id,
offsetting_special_char=self.offsetting_special_char,
add_bos=False,
add_eos=False,
)
else:
tokens = bytes(text, encoding="utf-8", errors="ignore")
# Offsetting
tokens = [int(unit) + self.offsetting_special_char for unit in tokens]
if add_bos:
tokens.insert(0, self.bos_id)
if add_eos:
tokens.append(self.eos_id)
return tokens
def decode(self, tokens: list[int], cut_at_eos: bool = False):
if cut_at_eos:
for k, t in enumerate(tokens):
if t == self.eos_id:
tokens = tokens[: k + 1]
break
return bytes(
[
tok - self.offsetting_special_char
for tok in tokens
if tok - self.offsetting_special_char >= 0
]
).decode("utf-8", errors="ignore")
def get_token_offsets(self, text: str, tokens: list[int] | None = None):
# TODO: Figure out what this does
raise NotImplementedError()

View file

@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from typing import Any
from pydantic import BaseModel
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.byte_tokenizer import ByteTokenizer
from bytelatent.tokenizers.tiktoken_tokenizer import TikTokenTokenizer
try:
from sentencepiece import SentencePieceProcessor
has_sp = True
except ImportError:
has_sp = False
try:
import tiktoken
from tiktoken.load import load_tiktoken_bpe
has_tiktoken = True
except ImportError:
has_tiktoken = False
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
logger = logging.getLogger(__name__)
class MockTokenizer(Tokenizer):
n_words: int = 256
def encode(self, text: str, add_bos: bool, add_eos: bool):
return text
def decode(self, tokens):
raise NotImplementedError()
def get_token_offsets(
self, text: str, tokens: list[int] | None = None
) -> tuple[list[str]]:
raise NotImplementedError()
class TokenizerArgs(BaseModel):
name: str = "bytes"
init_kwargs: dict[str, Any] | None = None
def build(self) -> Tokenizer:
if self.init_kwargs is None:
init_kwargs = {}
else:
init_kwargs = self.init_kwargs
if self.name == "blt":
return BltTokenizer(**init_kwargs)
elif self.name == "bytes":
return ByteTokenizer(**init_kwargs)
elif self.name == "mock":
return MockTokenizer(**init_kwargs)
elif self.name == "sp":
assert has_sp, "sentencepiece not installed"
return SentencePieceTokenizer(**init_kwargs)
elif self.name == "tiktoken":
assert has_tiktoken, "tiktoken not installed"
return TikTokenTokenizer(**init_kwargs)
else:
raise NotImplementedError(f"{self.name} tokenizer type is not implemented")

View file

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
class ByteTokenizer(Tokenizer):
def __init__(self):
self.bos_id = 256
self.eos_id = 257
self.n_words = 258
def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
return tokens
def decode(self, tokens: list[int]):
byte_tokens = bytes([t for t in tokens if t < 256])
return byte_tokens.decode("utf-8", errors="backslashreplace")
def get_token_offsets(
self, text: str, tokens: list[int] | None = None
) -> tuple[list[str], list[int]]:
if tokens is None:
tokens = self.encode(text)
decoded_chars, offsets = [], []
byte_pos = 0
for token in tokens:
if token < 256:
char = bytes([token]).decode("utf-8", errors="ignore")
if char:
decoded_chars.append(char)
offsets.append(byte_pos)
byte_pos += len(char.encode("utf-8"))
return decoded_chars, offsets

View file

@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
SEP = " "
BOS_ID: int = 1
EOS_ID: int = 2
PAD_ID: int = -1
BOE_ID: int = 0
BPE_ID: int = 3
OFFSET: int = 4
BYTE_UNITS: int = 256

View file

@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
try:
from sentencepiece import SentencePieceProcessor
has_sp = True
except ImportError:
has_sp = False
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
logger = logging.getLogger(__name__)
class SentencePieceTokenizer(Tokenizer):
def __init__(
self, model_path: str, add_bos: bool = True, add_eos: bool = True
) -> None:
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
logger.info(f"Reloaded SentencePiece model from {model_path}")
# BOS / EOS token IDs
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.pad_id()
self.add_bos = add_bos
self.add_eos = add_eos
logger.info(
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
)
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
if add_bos is None:
add_bos = self.add_bos
if add_eos is None:
add_eos = self.add_eos
assert type(s) is str
tokens = (
[self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos
)
return tokens
def decode(self, tokens: list[int]):
return self.sp_model.decode(tokens)
def get_token_offsets(
self, text: str, tokens: list[int] | None = None
) -> tuple[list[str], list[int]]:
pieces = self.sp_model.encode_as_immutable_proto(text).pieces
substrs = [p.surface for p in pieces]
offsets = [p.begin for p in pieces]
return substrs, offsets

View file

@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
from bytelatent.constants import BLT_DATA
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
def test_tokenizer_bytes():
with open("fixtures/tokenizer_data.json") as f:
data = json.load(f)
examples: list[str] = data["texts"]
examples_tokens: list[list[int]] = data["tokens"]
tokenizer = BltTokenizer(bpe_delim=False)
for i in range(len(examples)):
assert tokenizer.encode(examples[i]) == examples_tokens[i]
def test_tokenizer_bpe():
with open("fixtures/tokenizer_data_bpe_delim.json") as f:
data = json.load(f)
examples: list[str] = data["texts"]
examples_tokens: list[list[int]] = data["tokens"]
tokenizer = BltTokenizer(bpe_delim=True)
for i in range(len(examples)):
assert tokenizer.encode(examples[i]) == examples_tokens[i]
def test_build_tokenizer_from_args():
tokenizer_args = TokenizerArgs(
name="blt",
init_kwargs={
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
},
)
tokenizer = tokenizer_args.build()
assert tokenizer.encode("test text") is not None

View file

@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from copy import copy
from pathlib import Path
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
try:
import tiktoken
from tiktoken.load import load_tiktoken_bpe
has_tiktoken = True
except ImportError:
has_tiktoken = False
DEFAULT_TIKTOKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
DEFAULT_TIKTOKEN_SPECIAL_TOKENS = {
"<|begin_of_text|>": 0,
"<|end_of_text|>": 1,
"<|fim_prefix|>": 2,
"<|fim_middle|>": 3,
"<|fim_end_fill|>": 253,
"<|fim_pad|>": 254,
"<|fim_suffix|>": 255,
}
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
logger = logging.getLogger(__name__)
class TikTokenTokenizer(Tokenizer):
def __init__(self, model_path: str) -> None:
mergeable_ranks = load_tiktoken_bpe(model_path)
all_special_tokens_with_ids = copy(DEFAULT_TIKTOKEN_SPECIAL_TOKENS)
missing_ids = set(range(256)) - set(all_special_tokens_with_ids.values())
for id in missing_ids:
all_special_tokens_with_ids[f"<|reserved_special_token_{id}|>"] = id
for name in all_special_tokens_with_ids:
all_special_tokens_with_ids[name] += len(mergeable_ranks)
self.tkt_model = tiktoken.core.Encoding(
name=Path(model_path).stem,
pat_str=DEFAULT_TIKTOKEN_PATTERN,
mergeable_ranks=mergeable_ranks,
special_tokens=all_special_tokens_with_ids,
)
self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>")
self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>")
self.n_words: int = self.tkt_model.n_vocab
logger.info(
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
)
def encode(self, s: str, add_bos: bool, add_eos: bool):
assert isinstance(s, str)
subs = []
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
return (
[self.bos_id] * add_bos
+ sum(self.tkt_model.encode_ordinary_batch(subs), start=[])
+ [self.eos_id] * add_eos
)
def decode(self, tokens: list[int]):
return self.tkt_model.decode(tokens)
def get_token_offsets(
self, text: str, tokens: list[int] | None = None
) -> tuple[list[str], list[int]]:
if tokens is not None:
token_bytes = self.tkt_model.decode_tokens_bytes(tokens)
else:
token_bytes = self.tkt_model.decode_tokens_bytes(
self.tkt_model.encode(text, allowed_special="all")
)
text_len, offsets = 0, []
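# UTF-8 continuation bytes (0x80-0xBF) never start a new character, so advance
# the character position only on leading bytes when mapping token bytes back
# to offsets in `text`.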
for token in token_bytes:
offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
substrs = [text[s:e] for s, e in zip(offsets, offsets[1:] + [None])]
return substrs, offsets

651
bytelatent/train.py Normal file
View file

@@ -0,0 +1,651 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import gc
import logging
import os
import sys
from contextlib import ExitStack
from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path
from timeit import default_timer as timer
from typing import Any, Dict, Type, TypeVar
import torch
import torch.distributed
import torch.nn.functional
import torch.nn.functional as F
import wandb
import xformers.profiler
from omegaconf import OmegaConf
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import lr_scheduler
from bytelatent.args import TrainArgs
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.data.data_types import DataLoaderState
from bytelatent.distributed import (
check_model_value_range,
clean_env,
dist_mean_dict,
get_device_mesh,
get_is_master,
get_world_size,
init_signal_handler,
parallelize_model,
requeue_slurm_job,
setup_env,
setup_torch_distributed,
)
from bytelatent.logger import init_logger
from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.optim import build_optimizer
from bytelatent.probe import AutoProbeD
from bytelatent.profiling import maybe_run_profiler
from bytelatent.stool import StoolArgs, launch_job
from bytelatent.transformer import (
build_fsdp_grouping_plan,
get_no_recompute_ops,
get_num_flop_per_token,
tp_parallelize,
)
logger = logging.getLogger()
T = TypeVar("T")
def flatten_dict(d, parent_key="", sep="_"):
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
"""
Converts a dictionary to a dataclass instance, recursively for nested structures.
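Example (sketch): dataclass_from_dict(StoolArgs, {"nodes": 2}) merges the
override into the StoolArgs defaults and returns a StoolArgs instance.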
"""
base = OmegaConf.structured(cls())
OmegaConf.set_struct(base, strict)
override = OmegaConf.create(data)
return OmegaConf.to_object(OmegaConf.merge(base, override))
@dataclass
class TrainState(Stateful):
step: int # Nb of steps taken by the optimizer
acc_step: int # Nb of accumulation steps done since last optimizer step
scheduler: lr_scheduler.LambdaLR
data_loader_state: DataLoaderState
scale: float = 1.0
def state_dict(self) -> Dict[str, Any]:
return {
"step": self.step,
"acc_step": self.acc_step,
"data_loader_state": self.data_loader_state.dict(),
"scheduler": self.scheduler.state_dict(),
}
def load_state_dict(self, state_dict):
self.step = state_dict["step"]
self.acc_step = state_dict["acc_step"]
self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"])
self.scheduler.load_state_dict(state_dict["scheduler"])
def validate_train_args(args: TrainArgs, output_size: int):
if args.model.vocab_size < 0:
logger.info(f"Setting model output size to {output_size}")
args.model.vocab_size = output_size
assert args.dump_dir, "Dump dir not set"
if args.checkpoint.path is None:
args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
for source in args.data.sources:
data_path = os.path.join(args.data.root_dir, source)
assert os.path.exists(data_path), f"{data_path} doesn't exist"
if (
args.distributed.dp_replicate
* args.distributed.dp_shard
* args.distributed.tp_size
!= get_world_size()
):
assert get_world_size() % args.distributed.dp_shard == 0
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
args.distributed.dp_replicate = (
args.distributed.dp_replicate // args.distributed.tp_size
)
logger.warning(
f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
)
assert (
args.distributed.dp_replicate
* args.distributed.dp_shard
* args.distributed.tp_size
== get_world_size()
)
if args.distributed.fsdp_type == "no_shard":
assert (
args.distributed.dp_shard == 1
and args.distributed.dp_replicate == get_world_size()
)
args.model.max_seqlen = args.data.seq_len
if args.distributed.tp_size == 1:
logger.warning(
"Tensor parallelism has not been tested for a while, use at your own risk"
)
assert (
args.probe_freq != args.profiling.mem_steps
), "Don't profile during probe step"
assert (
args.probe_freq != args.profiling.profile_steps
), "Don't profile during probe step"
if args.logging.wandb is not None:
args.logging.wandb.name = args.name
if args.probe_freq is not None:
assert (
args.distributed.tp_size == 1
), "Probing not supported with tensor parallelism"
assert (
args.distributed.selective_activation_checkpointing is False
), "Probing not supported with selective activation checkpointing"
preemption_flag = dict(flag=False)
def set_preemption_flag(signum, frame):
logger.warning("Signal handler called with signal " + str(signum))
logger.warning("Preemption ! checkpointing asap and exiting.")
preemption_flag["flag"] = True
def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
test = train_state.step % freq == 0
if acc_step is not None:
test = test and (train_state.acc_step == acc_step)
elif acc_freq is not None:
test = test and ((train_state.acc_step % acc_freq) == 0)
return test
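# Per-token cross-entropy: `scale` multiplies the raw loss (e.g. for loss
# scaling) and `mask` selects which target tokens contribute to the mean.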
def compute_loss(p, y, mask, scale):
tok_loss = scale * F.cross_entropy(
p.flatten(0, 1), y.flatten(0, 1), reduction="none"
)
if mask is None:
loss = tok_loss.mean()
else:
mask = mask.flatten(0, 1)
tok_loss = tok_loss * mask
loss = tok_loss.sum() / (mask.sum() + 1e-6)
return loss, tok_loss
def train(args: TrainArgs):
with ExitStack() as context_stack:
tokenizer = args.data.tokenizer_args.build()
validate_train_args(
args,
tokenizer.n_words,
)
if get_is_master():
os.makedirs(args.dump_dir, exist_ok=True)
args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml")
init_logger(Path(args.dump_dir) / "train.log")
init_signal_handler(set_preemption_flag) # For handling preemption signals.
setup_env(args.env)
setup_torch_distributed(args.distributed)
world_mesh = get_device_mesh(args.distributed)
logger.info(f"Starting job: {args.name}")
# build dataloader
# need dp world size and rank
dp_mesh = world_mesh["dp_replicate"]
dp_degree = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()
if args.distributed.dp_shard > 1:
dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
dp_degree *= world_mesh["dp_shard"].size()
logger.info(f"Running on dp rank : {dp_rank}")
logger.info(f"Running on dp size : {dp_degree}")
torch.manual_seed(args.seed)
logger.info("Building model")
# Initializing the model on the meta device lets us build models much bigger than a single GPU's memory
with torch.device("meta"):
model = ByteLatentTransformer(args.model)
logger.info("Model is built !")
model_param_count = get_num_params(model)
model = parallelize_model(
model,
world_mesh,
args.model,
args.distributed,
fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
tp_parallelize=tp_parallelize,
no_recompute_ops=get_no_recompute_ops(),
)
# Once we shard the model on different gpus we can actually initialize the model
# First we create empty tensors of the correct shapes
model = model.to_empty(device="cuda")
# Then we init the model. Please make sure this function initializes *ALL* parameters
# and buffers, otherwise the uninitialized tensors will contain garbage values
# that fail silently (e.g. NaN gradients)
if args.checkpoint.init_ckpt_path:
logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
load_from_checkpoint(
args.checkpoint.init_ckpt_path, model, model_key="model"
) # Put model_key="" if its directly the model checkpoint
model.rope_embeddings.reset_parameters() # Re-init RoPE: it's a buffer, so it might not be loaded from the checkpoint
else:
with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
torch.manual_seed(args.model.seed)
model.init_weights()
check_model_value_range(model, range=10.0, std=1.0)
# log model size
logger.info(f"Model size: {model_param_count:,} total parameters")
gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(
f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
)
logger.info(f"GPU memory usage: {gpu_memory_monitor}")
# build optimizer after apply parallelisms to the model
optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
data_loader = args.data.build_from_rank(dp_rank, dp_degree)
data_loader_state = data_loader.get_state()
train_state = TrainState(
step=0,
acc_step=0,
data_loader_state=data_loader_state,
scheduler=scheduler,
scale=1.0,
)
checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
checkpoint.load(model, optimizer, train_state, world_mesh)
# Either load from latest checkpoint or start from scratch
if args.probe_freq is not None:
if get_is_master():
os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
torch.distributed.barrier()
probe = AutoProbeD(
model,
(
Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
if (dp_rank % 128 == 0)
else None
),
)
probe_mod = model._orig_mod if args.distributed.compile else model
gc.disable()
# train loop
model.train()
metric_logger = context_stack.enter_context(
MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
)
data_loader = train_state.data_loader_state.build()
batch_iterator = data_loader.create_iter()
torch_profiler = context_stack.enter_context(
maybe_run_profiler(args.dump_dir, model, args.profiling)
)
nwords_since_last_log = 0
time_last_log = timer()
gc.collect()
while train_state.step < args.steps:
# We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
train_state.acc_step += 1
train_state.acc_step = train_state.acc_step % args.grad_acc_steps
# get batch
curr_lr = float(optimizer.param_groups[0]["lr"])
data_load_start = timer()
batch = next(batch_iterator)
batch_x = torch.from_numpy(
batch.x,
).cuda()
batch_y = torch.from_numpy(batch.y).cuda()
batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None:
raise ValueError(
"Cannot enable byte ngrams and have batch.ngram_ids be None"
)
ngram_ids = (
None
if batch.ngram_ids is None
else torch.from_numpy(batch.ngram_ids).cuda()
)
if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
logger.info("garbage collection")
# We run garbage collection manually; otherwise different processes
# trigger GC at different times and slow down the whole pipeline
gc.collect()
data_load_time = round(timer() - data_load_start, 4)
nwords_since_last_log += batch_x.numel()
bsz, seqlen = batch_y.shape
# forward
start_timer = torch.cuda.Event(enable_timing=True)
end_timer = torch.cuda.Event(enable_timing=True)
start_timer.record()
# This is an automatic probe that will compute statistics
# of all linears' inputs, weights and outputs
# along with attention logits and entropy
# in both the forward and backward passes
tok_loss = None
if (args.probe_freq is not None) and every_n_steps(
train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
):
# Here we do a fake forward and backward pass on a smaller
# batch size to avoid OOM
# This assumes the model has no stateful layers (e.g. batch norm)
assert (
next(probe_mod.parameters()).grad is None
), "Can't probe model if grads are not reset"
with probe:
probe.metadata = {
"it": train_state.step,
"global_step": train_state.step,
"loop": "lingua",
}
# The non-compiled model uses roughly 2x memory in our experiments,
# so we halve bsz (or halve seqlen when bsz is already 1)
probe_bsz = max(1, bsz // 2)
probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
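# e.g. bsz=8 -> probe_bsz=4 with the full seqlen; bsz=1 -> probe_bsz=1 with seqlen//2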
probe_loss = probe_mod(
batch_x[:probe_bsz, :probe_seq],
batch_y[:probe_bsz, :probe_seq],
)
probe_loss.backward()
# We zero grads to cancel this fake step
optimizer.zero_grad()
assert (
next(probe_mod.parameters()).grad is None
), "Probe model shouldn't have grads at this point"
pred = model(
batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
)
loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
# We scale loss with grad_acc_steps so the gradient is the same
# regardless of grad_acc_steps
loss = loss / args.grad_acc_steps
# backward on scaled loss to create scaled gradients
loss.backward()
# For logging we undo that scaling
loss = loss.detach() * args.grad_acc_steps
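# e.g. with grad_acc_steps=4 each micro-batch contributes loss/4 to the gradients,
# so after 4 micro-batches the accumulated gradient matches one 4x-larger batch,
# while the logged loss stays the unscaled per-micro-batch value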
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=args.optim.clip, foreach=True
)
grad_norm = (
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
).item()
# optimizer step
if train_state.acc_step == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
train_state.step += 1
# updates the scale for next iteration
# training iteration complete
end_timer.record()
torch.cuda.synchronize()
curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)
# if profiler is active
if torch_profiler:
xformers.profiler.step()
# log metrics
if every_n_steps(
train_state,
args.logging.freq,
acc_step=None if args.logging.acc_freq else 0,
acc_freq=args.logging.acc_freq,
):
time_delta = timer() - time_last_log
wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)
gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
total_acc_steps = (
args.grad_acc_steps * train_state.step + train_state.acc_step
)
tokens_per_gpu = (
total_acc_steps * args.data.batch_size * args.data.seq_len
)
total_tokens = dp_degree * tokens_per_gpu
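# Illustrative (hypothetical numbers): with grad_acc_steps=2, batch_size=4, seq_len=4096,
# after 100 optimizer steps total_acc_steps=200, so tokens_per_gpu = 200 * 4 * 4096 ~= 3.3M
# and total_tokens is dp_degree times that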
# This is an estimate; the correct values may change
# if you change the architecture.
# Use xformers' profiler trace analysis to get an actual measurement.
FLOPS = (
get_num_flop_per_token(
model_param_count - args.model.vocab_size * args.model.dim,
args.model.n_layers,
args.model.dim,
args.data.seq_len,
)
* wps
)
metrics = flatten_dict(
{
"global_step": train_state.step,
"acc_step": train_state.acc_step,
"speed": {
"wps": wps,
"FLOPS": FLOPS,
"curr_iter_time": curr_iter_time,
"data_load_time": data_load_time,
},
"optim": {
"grad_norm": grad_norm,
"lr": curr_lr,
"total_tokens": total_tokens,
},
"memory": gpu_mem_stats._asdict(),
},
sep="/",
)
to_sync = {}
to_sync["loss/out"] = loss.item()
metrics.update(dist_mean_dict(to_sync))
if get_is_master():
metric_logger.log(metrics)
gpu_memory_monitor.reset_peak_stats()
nwords_since_last_log = 0
time_last_log = timer()
logger.info(
f"step: {train_state.step}"
f" acc: {train_state.acc_step}"
f" loss: {round(loss.item(),4):>7}"
f" grad: {grad_norm:.2e}"
f" flops: {FLOPS:.2e}"
f" wps: {wps:.2e}"
f" iter: {curr_iter_time:>7}"
f" data: {data_load_time:>5}"
f" lr: {curr_lr:.2e}"
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
f" pow: {gpu_mem_stats.power_draw/1000} W"
)
saved = False
if every_n_steps(
train_state, args.checkpoint.dump.every, acc_step=0
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
saved = checkpoint.save(
model,
optimizer,
train_state,
args,
device_mesh=world_mesh,
)
if args.eval is not None and every_n_steps(
train_state, args.checkpoint.eval.every, acc_step=0
):
from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
eval_args = dataclass_from_dict(EvalArgs, args.eval)
eval_args.global_step = train_state.step
eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
eval_args.dump_dir = str(
os.path.join(
args.dump_dir,
"evals",
EVAL_FOLDER_NAME.format(train_state.step),
)
)
eval_args.metric_log_dir = args.dump_dir
if args.async_eval_gpus is None:
launch_eval(eval_args)
elif get_is_master():
if wandb.run is not None and args.logging.wandb is not None:
eval_args.wandb = deepcopy(args.logging.wandb)
assert args.async_eval_gpus > 0
logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
with clean_env():
launch_job(
StoolArgs(
asdict(eval_args),
script="apps.main.eval",
copy_code=False,
nodes=args.async_eval_gpus // 8,
qos="lowest",
)
)
if preemption_flag["flag"]:
if not saved:
checkpoint.save(
model,
optimizer,
train_state,
args,
device_mesh=world_mesh,
)
requeue_slurm_job()
sys.exit(0)
if not saved:
checkpoint.save(
model,
optimizer,
train_state,
args,
device_mesh=world_mesh,
)
gc.collect()
def main():
"""
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
This accepts arguments as a dot list
So if the dataclass looks like
@dataclass
class DummyArgs:
name: str
model: LMTransformerArgs
@dataclass
class LMTransformerArgs:
dim: int
Then you can pass model.dim=32 to change values in LMTransformerArgs
or just name=tictac for top level attributes.
The behavior here is as follows:
1. We instantiate TrainArgs with its default values
2. We override those default values with the ones in the provided config file
3. We override the result with the additional arguments provided through command line
For example, if the config is the following
model:
dim: 128
n_layers: 4
and you call train.py with train.py model.dim=64
Then the final TrainArgs will have
model:
dim: 64
n_layers: 4
Plus all the default values in TrainArgs dataclass.
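Example invocation (module path and config path are illustrative):
python -m apps.main.train config=path/to/config.yaml model.dim=64 name=my_run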
"""
cli_args = OmegaConf.from_cli()
file_cfg = OmegaConf.load(cli_args.config)
# We remove the 'config' attribute from the config as the underlying dataclass does not have it
del cli_args.config
default_cfg = OmegaConf.create(TrainArgs().model_dump())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
train_args = TrainArgs.model_validate(cfg)
train(train_args)
if __name__ == "__main__":
main()

216
bytelatent/transformer.py Normal file
View file

@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
parallelize_module,
)
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from xformers.ops import AttentionBias, fmha
from bytelatent.base_transformer import (
BaseTransformer,
BaseTransformerArgs,
RMSNorm,
cross_entropy,
)
def create_causal_mask(seqlen, attn_impl, sliding_window):
if sliding_window is not None and attn_impl == "xformers":
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
window_left=sliding_window - 1, window_right=0
)
elif attn_impl == "xformers":
return fmha.attn_bias.LowerTriangularMask()
elif attn_impl == "sdpa":
return "causal"
elif attn_impl == "flex_attention":
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
else:
raise NotImplementedError(
f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
)
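# Usage sketch (illustrative values): create_causal_mask(2048, "xformers", 512) returns a
# LocalAttentionFromBottomRightMask limiting each query to itself plus the previous 511 keys,
# while create_causal_mask(2048, "sdpa", None) simply returns the string "causal"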
def attention_flops_per_token(n_layers, seq_len, dim, causal):
# Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))
def get_num_flop_per_token(
num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
) -> int:
return 6 * num_non_embed_params + attention_flops_per_token(
n_layers, seq_len, dim, True
)
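# Rough illustration (hypothetical numbers): with 1e9 non-embedding parameters,
# n_layers=24, dim=2048, seq_len=4096, the attention term adds
# 3.5 * 4 * 24 * 4096 * 2048 / 2 ~= 1.4e9 FLOPs per token on top of the 6e9 from 6 * N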
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
class LMTransformerArgs(BaseTransformerArgs):
seed: int = 42
vocab_size: int = -1
weight_tying: bool = False
sliding_window: int | None = None
class LMTransformer(BaseTransformer):
def __init__(self, args: LMTransformerArgs):
super().__init__(args)
self.weight_tying = args.weight_tying
self.sliding_window = args.sliding_window
assert args.vocab_size > 0
self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(
args.dim,
args.vocab_size,
bias=False,
)
if args.weight_tying:
self.output.weight = self.tok_embeddings.weight
def forward(
self,
token_values: torch.Tensor,
target: Optional[torch.Tensor] = None,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
attn_impl: str = "sdpa",
):
bsz, seqlen = token_values.shape
h = self.tok_embeddings(token_values)
mask = (
mask
if mask is not None
else create_causal_mask(seqlen, attn_impl, self.sliding_window)
)
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
logits = self.output(self.norm(h))
if target is not None:
return cross_entropy(logits, target)
else:
return logits
def reset_parameters(self, init_std=None):
# Either use a fixed base std or 1/sqrt(model dim)
super().reset_parameters()
init_std = init_std or (self.dim ** (-0.5))
self.norm.reset_parameters()
nn.init.trunc_normal_(
self.tok_embeddings.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
if not self.weight_tying:
nn.init.trunc_normal_(
self.output.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
# Optional policy for activation checkpointing. With None, we stick to the default (defined in distributed.py: default_no_recompute_ops)
def get_no_recompute_ops():
return None
# Optional and only used when the fully_shard (FSDP) option is chosen. Highly recommended for large models
def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
group_plan: Tuple[int, bool] = []
# Group the token embeddings and the output separately
group_plan.append(("tok_embeddings", False))
# Grouping by layers
for i in range(model_args.n_layers):
group_plan.append((f"layers.{i}", False))
group_plan.append(("output", True))
return group_plan
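# Illustrative plan for a 2-layer model:
# [("tok_embeddings", False), ("layers.0", False), ("layers.1", False), ("output", True)]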
# Optional and only used for model/tensor parallelism when tp_size > 1
def tp_parallelize(model, tp_mesh, model_args: LMTransformerArgs, distributed_args):
assert model_args.dim % distributed_args.tp_size == 0
assert model_args.vocab_size % distributed_args.tp_size == 0
assert model_args.n_heads % distributed_args.tp_size == 0
assert (model_args.n_kv_heads or 0) % distributed_args.tp_size == 0
assert model_args.n_heads % (model_args.n_kv_heads or 1) == 0
# Embedding layer tp
main_plan = {}
main_plan["tok_embeddings"] = ColwiseParallel(
input_layouts=Replicate(), output_layouts=Shard(1)
)
main_plan["norm"] = SequenceParallel()
main_plan["output"] = ColwiseParallel(
input_layouts=Shard(1), output_layouts=Replicate()
)
parallelize_module(
model,
tp_mesh,
main_plan,
)
# Attention layers tp
for layer in model.layers:
layer_plan = {}
layer_plan["attention"] = PrepareModuleInput(
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
)
layer_plan["attention_norm"] = SequenceParallel()
layer_plan["attention.wq"] = ColwiseParallel()
layer_plan["attention.wk"] = ColwiseParallel()
layer_plan["attention.wv"] = ColwiseParallel()
layer_plan["attention.wo"] = RowwiseParallel(output_layouts=Shard(1))
# Feedforward layers tp
layer_plan["feed_forward"] = PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
)
layer_plan["ffn_norm"] = SequenceParallel()
layer_plan["feed_forward.w1"] = ColwiseParallel()
layer_plan["feed_forward.w3"] = ColwiseParallel()
layer_plan["feed_forward.w2"] = RowwiseParallel(output_layouts=Shard(1))
parallelize_module(
layer,
tp_mesh,
layer_plan,
)
# Adjusting the number of heads and kv heads according to the tp size
attn_layer = layer.attention
attn_layer.n_heads = attn_layer.n_heads // distributed_args.tp_size
attn_layer.n_kv_heads = attn_layer.n_kv_heads // distributed_args.tp_size
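# e.g. with tp_size=2 and n_heads=16 / n_kv_heads=8 (illustrative values), each TP rank
# is left with 8 local heads and 4 local kv heads after the adjustment above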

3
dev/lint.sh Normal file
View file

@ -0,0 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
isort .
black .

View file

@ -0,0 +1 @@
{"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 80, 105, 120, 43, 119, 36, 103, 108, 105, 103, 111, 36, 109, 106, 36, 120, 108, 105, 119, 105, 36, 120, 115, 111, 105, 114, 109, 126, 105, 118, 119, 36, 113, 101, 120, 103, 108, 37, 2]]}

View file

@ -0,0 +1 @@
{"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 3, 80, 105, 120, 3, 43, 3, 119, 3, 36, 103, 108, 105, 103, 111, 3, 36, 109, 106, 3, 36, 120, 108, 105, 119, 105, 3, 36, 120, 115, 111, 105, 114, 3, 109, 126, 105, 118, 119, 3, 36, 113, 101, 120, 103, 108, 3, 37, 2]]}

View file

@ -0,0 +1 @@
{"position":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":6,"7":7,"8":8,"9":9,"10":10,"11":11,"12":12,"13":13,"14":14,"15":15,"16":16,"17":17,"18":18,"19":19,"20":20,"21":21,"22":22,"23":23,"24":24,"25":25,"26":26,"27":27,"28":28,"29":29,"30":30,"31":31,"32":32,"33":33,"34":34,"35":35,"36":36,"37":37,"38":38,"39":39,"40":40,"41":41,"42":42,"43":43,"44":44,"45":45,"46":46,"47":47,"48":48,"49":49,"50":50,"51":51,"52":52,"53":53,"54":54,"55":55,"56":56,"57":57,"58":58,"59":59,"60":60,"61":61,"62":62,"63":63,"64":64,"65":65,"66":66,"67":67,"68":68,"69":69,"70":70,"71":71,"72":72,"73":73,"74":74,"75":75,"76":76,"77":77,"78":78,"79":79,"80":80},"tokens":{"0":"<","1":"D","2":"a","3":"e","4":"n","5":"e","6":"r","7":"y","8":"s","9":"_","10":"T","11":"a","12":"r","13":"g","14":"a","15":"r","16":"y","17":"e","18":"n","19":"_","20":"i","21":"s","22":"_","23":"i","24":"n","25":"_","26":"G","27":"a","28":"m","29":"e","30":"_","31":"o","32":"f","33":"_","34":"T","35":"h","36":"r","37":"o","38":"n","39":"e","40":"s","41":",","42":"_","43":"a","44":"_","45":"f","46":"a","47":"n","48":"t","49":"a","50":"s","51":"y","52":"_","53":"e","54":"p","55":"i","56":"c","57":"_","58":"b","59":"y","60":"_","61":"G","62":"e","63":"o","64":"r","65":"g","66":"e","67":"_","68":"R","69":".","70":"R","71":".","72":"_","73":"M","74":"a","75":"r","76":"t","77":"i","78":"n","79":".","80":">"},"token_ids":{"0":1,"1":72,"2":101,"3":105,"4":114,"5":105,"6":118,"7":125,"8":119,"9":36,"10":88,"11":101,"12":118,"13":107,"14":101,"15":118,"16":125,"17":105,"18":114,"19":36,"20":109,"21":119,"22":36,"23":109,"24":114,"25":36,"26":75,"27":101,"28":113,"29":105,"30":36,"31":115,"32":106,"33":36,"34":88,"35":108,"36":118,"37":115,"38":114,"39":105,"40":119,"41":48,"42":36,"43":101,"44":36,"45":106,"46":101,"47":114,"48":120,"49":101,"50":119,"51":125,"52":36,"53":105,"54":116,"55":109,"56":103,"57":36,"58":102,"59":125,"60":36,"61":75,"62":105,"63":115,"64":118,"65":107,"66":105,"67":36,"68":86,"69":50,"70":86,"71":50,"72":36,"73":81,"74":101,"75":118,"76":120,"77":109,"78":114,"79":50,"80":2},"entropies":{"0":3.3949158192,"1":2.1656451225,"2":2.3216569424,"3":2.8214058876,"4":1.5249242783,"5":0.0401624143,"6":0.0981037766,"7":0.0544578359,"8":0.3430138826,"9":1.0546212196,"10":0.25252828,"11":0.1494535804,"12":0.0624754503,"13":0.001355894,"14":0.0050173439,"15":0.0052358187,"16":0.0011725067,"17":0.0010307421,"18":1.0241208076,"19":3.6867966652,"20":0.4502205253,"21":0.0484119244,"22":2.2572875023,"23":0.3789347112,"24":1.0042934418,"25":2.9090054035,"26":1.8933598995,"27":1.3859074116,"28":0.3827198744,"29":0.2646365762,"30":1.7742085457,"31":0.0136727821,"32":0.0053820172,"33":0.5485631227,"34":0.2064044327,"35":0.0049266233,"36":0.0005439016,"37":0.0007023578,"38":0.0004170335,"39":0.0054524317,"40":1.1938130856,"41":0.0238215197,"42":3.1279797554,"43":1.3883389235,"44":3.0503094196,"45":1.695879817,"46":1.8551058769,"47":1.4570231438,"48":0.0047810897,"49":0.026396824,"50":0.6633765101,"51":0.3141393065,"52":2.8411159515,"53":1.143143177,"54":0.0520330966,"55":0.3398066461,"56":0.4140175879,"57":2.5563707352,"58":1.3370712996,"59":0.0227173548,"60":3.4447185993,"61":1.8576486111,"62":0.8189754486,"63":0.6776530743,"64":0.0677763447,"65":0.212713033,"66":0.1003480032,"67":0.1746164262,"68":0.4123829603,"69":0.5507118702,"70":0.1047425047,"71":0.0194335245,"72":0.001482119,"73":0.0009310447,"74":0.0002176317,"75":0.0076908777,"76":0.0003866984,"77":0.0008008487,"78":1.2395234108,"79":0.4564163089,"80":0.0000461392},"patch":{"0":0,"1":1,"2":
2,"3":3,"4":4,"5":5,"6":5,"7":5,"8":5,"9":5,"10":5,"11":5,"12":5,"13":5,"14":5,"15":5,"16":5,"17":5,"18":5,"19":5,"20":6,"21":6,"22":6,"23":7,"24":7,"25":7,"26":8,"27":9,"28":10,"29":10,"30":10,"31":11,"32":11,"33":11,"34":11,"35":11,"36":11,"37":11,"38":11,"39":11,"40":11,"41":11,"42":11,"43":12,"44":13,"45":14,"46":15,"47":16,"48":17,"49":17,"50":17,"51":17,"52":17,"53":18,"54":18,"55":18,"56":18,"57":18,"58":19,"59":20,"60":20,"61":21,"62":22,"63":22,"64":22,"65":22,"66":22,"67":22,"68":22,"69":22,"70":22,"71":22,"72":22,"73":22,"74":22,"75":22,"76":22,"77":22,"78":22,"79":22,"80":22},"start":{"0":1,"1":1,"2":1,"3":1,"4":1,"5":1,"6":0,"7":0,"8":0,"9":0,"10":0,"11":0,"12":0,"13":0,"14":0,"15":0,"16":0,"17":0,"18":0,"19":0,"20":1,"21":0,"22":0,"23":1,"24":0,"25":0,"26":1,"27":1,"28":1,"29":0,"30":0,"31":1,"32":0,"33":0,"34":0,"35":0,"36":0,"37":0,"38":0,"39":0,"40":0,"41":0,"42":0,"43":1,"44":1,"45":1,"46":1,"47":1,"48":1,"49":0,"50":0,"51":0,"52":0,"53":1,"54":0,"55":0,"56":0,"57":0,"58":1,"59":1,"60":0,"61":1,"62":1,"63":0,"64":0,"65":0,"66":0,"67":0,"68":0,"69":0,"70":0,"71":0,"72":0,"73":0,"74":0,"75":0,"76":0,"77":0,"78":0,"79":0,"80":0}}

File diff suppressed because one or more lines are too long

5
pyproject.toml Normal file
View file

@ -0,0 +1,5 @@
[tool.isort]
profile = "black"
known_bytelatent = "bytelatent"
known_apps = "apps"
sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER"

22
requirements.txt Normal file
View file

@ -0,0 +1,22 @@
numpy
omegaconf
msgspec
rouge-score
sacrebleu
sentencepiece
tiktoken
fsspec
blobfile
wandb
viztracer
lm-eval
scipy
pynvml
datatrove
orjson
luigi
pydantic
altair
submitit
typer
rich

48
setup/create_env.sh Normal file
View file

@ -0,0 +1,48 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
#SBATCH --job-name=env_creation
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gres=gpu:8
#SBATCH --exclusive
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=128
#SBATCH --mem=0
#SBATCH --time=01:00:00
# Exit immediately if a command exits with a non-zero status
set -e
# Start timer
start_time=$(date +%s)
# Get the current date
current_date=$(date +%y%m%d)
# Create environment name with the current date
env_prefix=blt_$current_date
# Create the conda environment
source $CONDA_ROOT/etc/profile.d/conda.sh
conda create -n $env_prefix python=3.11 -y -c anaconda
conda activate $env_prefix
echo "Currently in env $(which python)"
# Install packages
pip install torch==2.5.0 xformers --index-url https://download.pytorch.org/whl/cu121
pip install ninja
pip install --requirement requirements.txt
# End timer
end_time=$(date +%s)
# Calculate elapsed time in seconds
elapsed_time=$((end_time - start_time))
# Convert elapsed time to minutes
elapsed_minutes=$((elapsed_time / 60))
echo "Environment $env_prefix created and all packages installed successfully in $elapsed_minutes minutes!"

View file

@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import argparse
import os
import subprocess
import time
import requests
from huggingface_hub import snapshot_download
def run_command(command):
print(f"Running: {command}")
subprocess.run(command, shell=True, check=True)
def download_dataset(repo_id, local_dir, allow_patterns):
print(f"Downloading dataset from {repo_id}...")
max_retries = 5
retry_delay = 10 # seconds
for attempt in range(max_retries):
try:
snapshot_download(
repo_id,
repo_type="dataset",
local_dir=local_dir,
allow_patterns=allow_patterns,
resume_download=True,
max_workers=16, # Don't hesitate to increase this number to lower the download time
)
break
except requests.exceptions.ReadTimeout:
if attempt < max_retries - 1:
print(f"Timeout occurred. Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else:
raise
print(f"Dataset downloaded to {local_dir}")
def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.readers import ParquetReader
from datatrove.pipeline.writers import JsonlWriter
pipeline_exec = LocalPipelineExecutor(
pipeline=[
ParquetReader(
src_dir,
file_progress=True,
doc_progress=True,
glob_pattern="**/*.parquet",
),
JsonlWriter(
tgt_dir,
output_filename=dataset + ".chunk.${rank}.jsonl",
compression=None,
),
],
tasks=ntasks,
logging_dir=os.path.join(work_dir, "datatrove"),
)
pipeline_exec.run()
def setup_terashuf(work_dir):
terashuf_dir = os.path.join(work_dir, "terashuf")
terashuf_executable = os.path.join(terashuf_dir, "terashuf")
if os.path.exists(terashuf_executable):
print("terashuf executable already exists. Skipping setup.")
return terashuf_dir
print("Setting up terashuf...")
run_command(f"git clone https://github.com/alexandres/terashuf {terashuf_dir}")
run_command(f"make -C {terashuf_dir}")
return terashuf_dir
def main(dataset, memory, data_dir, seed=42, nchunks=32):
# Configuration
repo_id = {
"fineweb_edu": "HuggingFaceFW/fineweb-edu",
"fineweb_edu_10bt": "HuggingFaceFW/fineweb-edu",
"dclm_baseline_1.0": "mlfoundations/dclm-baseline-1.0",
"dclm_baseline_1.0_10prct": "mlfoundations/dclm-baseline-1.0",
}[dataset]
src_dir = f"{data_dir}/{dataset}"
out_dir = f"{src_dir}_shuffled"
os.makedirs(out_dir, exist_ok=True)
work_dir = src_dir  # Use the dataset source directory as the working directory
prefix = f"{dataset}.chunk."
orig_extension = {
"fineweb_edu": ".jsonl",
"fineweb_edu_10bt": ".jsonl",
"dclm_baseline_1.0": ".jsonl.zst",
"dclm_baseline_1.0_10prct": ".jsonl.zst",
}[dataset]
cat_command = {
"fineweb_edu": "cat",
"fineweb_edu_10bt": "cat",
"dclm_baseline_1.0": "zstdcat",
"dclm_baseline_1.0_10prct": "zstdcat",
}[dataset]
allow_patterns = {
"fineweb_edu": None,
"fineweb_edu_10bt": "sample/10BT/*",
"dclm_baseline_1.0": "*.jsonl.zst",
"dclm_baseline_1.0_10prct": "global-shard_01_of_10/*.jsonl.zst",
}[dataset]
suffix = ".jsonl"
k_validation = 10000 # Number of lines to take from each chunk for validation
# Setup terashuf
terashuf_dir = setup_terashuf(work_dir)
# Download dataset
download_dataset(repo_id, src_dir, allow_patterns)
if "fineweb" in dataset:
parquet_to_jsonl(dataset, work_dir, src_dir, src_dir)
# Set up environment variables
os.environ["MEMORY"] = f"{memory}"
os.environ["SEED"] = f"{seed}"
# Run the original shuffling and splitting command
terashuf_executable = os.path.join(terashuf_dir, "terashuf")
run_command(
f"ulimit -n 100000 && "
f"find {src_dir} -type f -name '*{orig_extension}' -print0 | xargs -0 {cat_command} | {terashuf_executable} | "
f"split -n r/{nchunks} -d --suffix-length 2 --additional-suffix {suffix} - {out_dir}/{prefix}"
"; trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' SIGPIPE;"
)
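# The command above streams every matching source file through terashuf (an external
# out-of-core shuffler; the MEMORY/SEED env vars set above are picked up by it) and
# round-robin splits the shuffled stream ("split -n r/N") into nchunks files named
# {prefix}00.jsonl, {prefix}01.jsonl, ...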
# Create validation set and remove lines from chunks
validation_file = f"{out_dir}/{dataset}.val{suffix}"
for i in range(nchunks):
chunk_file = f"{out_dir}/{prefix}{i:02d}{suffix}"
run_command(f"head -n {k_validation} {chunk_file} >> {validation_file}")
run_command(f"sed -i '1,{k_validation}d' {chunk_file}")
print("All tasks completed successfully!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("dataset", type=str)
parser.add_argument("memory", type=float, default=8)
parser.add_argument("--data_dir", type=str, default="data")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--nchunks", type=int, default=32)
args = parser.parse_args()
main(args.dataset, args.memory, args.data_dir, args.seed, args.nchunks)

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import argparse
import os
from typing import Optional
from requests.exceptions import HTTPError
TOKENIZER = {
"llama2": ("meta-llama/Llama-2-7b", "tokenizer.model"),
"llama3": ("meta-llama/Meta-Llama-3-8B", "original/tokenizer.model"),
"gemma": ("google/gemma-2-9b", "tokenizer.model"),
}
def main(tokenizer_name: str, path_to_save: str, api_key: Optional[str] = None):
if tokenizer_name in TOKENIZER:
repo_id, filename = TOKENIZER[tokenizer_name]
from huggingface_hub import hf_hub_download
try:
hf_hub_download(
repo_id=repo_id,
filename=filename,
local_dir=path_to_save,
local_dir_use_symlinks=False,
token=api_key if api_key else None,
)
except HTTPError as e:
if e.response.status_code == 401:
print(
"You need to pass a valid `--hf_token=...` to download private checkpoints."
)
else:
raise e
else:
from tiktoken import get_encoding
if "TIKTOKEN_CACHE_DIR" not in os.environ:
os.environ["TIKTOKEN_CACHE_DIR"] = path_to_save
try:
get_encoding(tokenizer_name)
except ValueError:
print(
f"Tokenizer {tokenizer_name} not found. Please check the name and try again."
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("tokenizer_name", type=str)
parser.add_argument("tokenizer_dir", type=str)
parser.add_argument("--api_key", type=str, default="")
args = parser.parse_args()
main(
tokenizer_name=args.tokenizer_name,
path_to_save=args.tokenizer_dir,
api_key=args.api_key,
)