From 936d9437be43c24bfe4718388f6c28d8d50427cd Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 09:57:22 -0800 Subject: [PATCH 1/2] Allow ArrowIterator to read from json (#45) Summary: Currently, arrow iterator can only read arrow files. However, the pyarrow library can read other formats, including jsonlines. This allows the same ArrowIterator to read from jsonlines, so we can read from the original source data, and simply omit the entropy column when doing so Test Plan: Run train script until dataloader starts --- bytelatent/args.py | 38 +++++++- bytelatent/data/iterators/arrow_iterator.py | 97 +++++++++++-------- bytelatent/preprocess/preprocess_entropies.py | 26 +++-- bytelatent/stool.py | 2 +- 4 files changed, 105 insertions(+), 58 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index fc72b32..263e8e3 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,8 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator -from bytelatent.data.iterators.arrow_iterator import ( - ArrowFileIterator, - find_and_sanitize_chunks, -) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator @@ -53,6 +53,33 @@ def parse_args(args_cls): return pydantic_args +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +89,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..1c68d3a 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,7 @@ from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) @@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = arrow_batch_size self.s3_profile = s3_profile self.filesystem_type = filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): diff --git a/bytelatent/stool.py b/bytelatent/stool.py index b156177..b47ddc7 100644 --- a/bytelatent/stool.py +++ b/bytelatent/stool.py @@ -4,10 +4,10 @@ import json import os import shutil import subprocess -from pydantic import BaseModel from typing import Any, Dict from omegaconf import OmegaConf +from pydantic import BaseModel class StoolArgs(BaseModel): From 4e2ed0aa050b7d9ec610c69cb7010d59586fcf8d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 18:08:01 +0000 Subject: [PATCH 2/2] Add bpb and n_bytes to metric logging Summary: Test Plan: --- bytelatent/distributed.py | 16 +++++++++++--- bytelatent/train.py | 46 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..b3f31a6 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor = torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" diff --git a/bytelatent/train.py b/bytelatent/train.py index 9bfe12a..5b31e0c 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -3,6 +3,7 @@ import gc import logging +import math import os import sys from contextlib import ExitStack @@ -11,6 +12,7 @@ from dataclasses import asdict, dataclass from timeit import default_timer as timer from typing import Any, TypeVar +import numpy as np import torch import torch.distributed import torch.nn.functional @@ -32,7 +34,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -392,6 +396,9 @@ def train(args: TrainArgs): time_last_log = timer() gc.collect() saved = False + step_losses: list[float] = [] + step_tok_losses: list[float] = [] + n_bytes: int = 0 while train_state.step < args.steps and ( args.max_steps is None or train_state.step < args.max_steps ): @@ -413,6 +420,24 @@ def train(args: TrainArgs): batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + if args.data.tokenizer_args.name in ["bytes", "blt"]: + if mask is None: + n_bytes += batch_y.numel() + else: + n_bytes += mask.sum() + elif args.data.tokenizer_args.name in ["sp", "tiktoken"]: + for example in batch.y: + target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False) + n_bytes += ( + len(bytes(target_tokens, encoding="utf-8", errors="ignore")) + + sum(example == tokenizer.eos_id) + + sum(example == tokenizer.bos_id) + ) + else: + raise ValueError( + f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}" + ) + if ( not args.train_entropy_model and args.model.encoder_enable_byte_ngrams @@ -487,7 +512,7 @@ def train(args: TrainArgs): batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids ) - loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps @@ -498,6 +523,10 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps + # Undo loss scaling so downstream down't need to worry about it + step_losses.append((loss / train_state.scale).item()) + step_tok_losses.append(tok_loss / train_state.scale) + world_size = get_world_size() if 1 < world_size <= 8: # For some reason, there are errors in reduces due to @@ -598,20 +627,33 @@ def train(args: TrainArgs): gpu_memory_monitor.reset_peak_stats() nwords_since_last_log = 0 time_last_log = timer() + stacked_tok_loss = torch.cat(step_tok_losses, dim=0) + total_tok_loss = dist_sum( + stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16 + ) + total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16) + avg_bpb = total_tok_loss / math.log(2) / total_n_bytes + avg_loss = dist_mean(np.mean(step_losses).item()) logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" loss: {round(loss.item(),4):>7}" + f" loss: step={round(loss.item(),4):>7} avg={avg_loss}" + f" bpb: {avg_bpb:3f}" f" grad: {grad_norm:.2e}" f" flops: {FLOPS:.2e}" f" wps: {wps:.2e}" f" iter: {curr_iter_time:>7}" f" data: {data_load_time:>5}" f" lr: {curr_lr:.2e}" + f" n_bytes={total_n_bytes}" f" mem: {gpu_mem_stats.max_active_pct:.0f}%" f" pow: {gpu_mem_stats.power_draw/1000} W" ) + n_bytes = 0 + step_losses = [] + step_tok_losses = [] + if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):