# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
from typing import Any

import numpy as np
import yaml
from pydantic import BaseModel, ConfigDict

from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    find_and_sanitize_chunks,
)
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.iterators.sampling_iterator import SamplingIterator
from bytelatent.data.iterators.sequence_iterator import (
    SequenceIterator,
    SequencePackingArgs,
)
from bytelatent.data.patcher import PatcherArgs
from bytelatent.distributed import DistributedArgs, EnvironmentArgs
from bytelatent.metrics import LoggingArgs
from bytelatent.model.blt import ByteLatentTransformerArgs
from bytelatent.optim import OptimArgs
from bytelatent.profiling import ProfilerArgs
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformerArgs

logger = logging.getLogger()


def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
    """Derive a deterministic RNG state that differs per (seed, rank, world_size)."""
    return np.random.default_rng((seed, rank, world_size)).bit_generator.state


def distribute_data_to_rank(
    *,
    dataset_path: str,
    preprocess_dir: str,
    entropy_model_name: str | None,
    arrow_batch_size: int,
    rank: int,
    world_size: int,
    s3_profile: str | None = None,
) -> ArrowFileIterator:
    """Assign dataset chunks to ranks and return the ArrowFileIterator for `rank`.

    Each chunk is shared by world_size // num_chunks workers; the flat list of
    (chunk, worker) iterators is indexed by rank.
    """
    dataset_chunks = find_and_sanitize_chunks(
        dataset_path, world_size, s3_profile=s3_profile
    )
    n_workers_per_chunk = world_size // len(dataset_chunks)
    rank_to_arrow_iterator_params = []
    for chunk_path in dataset_chunks:
        for worker_id in range(n_workers_per_chunk):
            rank_to_arrow_iterator_params.append(
                ArrowFileIterator(
                    file_path=chunk_path,
                    worker_id=worker_id,
                    num_workers=n_workers_per_chunk,
                    preprocess_dir=preprocess_dir,
                    dataset_files=None,
                    entropy_model_name=entropy_model_name,
                    arrow_batch_size=arrow_batch_size,
                    s3_profile=s3_profile,
                )
            )
    return rank_to_arrow_iterator_params[rank]


class DataloaderArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    s3_profile: str | None = None
    root_dir: str | None = None
    sources: dict[str, float] = {}
    batch_size: int = 2
    seq_len: int = 2048
    seed: int = 42
    add_bos: bool = True
    add_eos: bool = True
    load_async: bool = True
    prefetch_size: int = 64
    preprocess_dir: str | None = None
    dataset_files: list[str] | None = None
    entropy_model_name: str | None = "transformer_100m"
    arrow_batch_size: int = 100
    buffer_size: int = 64
    pad_to_max_length: bool = True
    max_encoder_seq_length: int = 12288
    enable_byte_ngrams: bool = False
    tokenizer_args: TokenizerArgs = TokenizerArgs()
    patcher_args: PatcherArgs = PatcherArgs()

    def _create_sequence_iterators(
        self, rank: int, world_size: int
    ) -> dict[str, SequenceIterator]:
        sequence_packing_args = SequencePackingArgs(
            output_seq_len=self.seq_len,
            buffer_size=self.buffer_size,
        )
        source_to_sequence_iterator: dict[str, SequenceIterator] = {}
        for dataset_path in self.sources:
            shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
            arrow_iterator = distribute_data_to_rank(
                dataset_path=os.path.join(self.root_dir, dataset_path),
                preprocess_dir=self.preprocess_dir,
                entropy_model_name=self.entropy_model_name,
                arrow_batch_size=self.arrow_batch_size,
                rank=rank,
                world_size=world_size,
                s3_profile=self.s3_profile,
            )
            looping_iterator = LoopingIterator(arrow_iterator)
            preprocess_iterator = PreprocessIterator(
                looping_iterator,
                patcher_args=self.patcher_args,
                tokenizer_args=self.tokenizer_args,
            )
            sequence_iterator = SequenceIterator(
                preprocess_iterator,
                sequence_packing_args=sequence_packing_args,
                rng_state=shuffle_rng_state,
            )
            source_to_sequence_iterator[dataset_path] = sequence_iterator
        return source_to_sequence_iterator
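
    # The per-source pipeline assembled above composes, in order:
    #   ArrowFileIterator -> LoopingIterator (repeat the source file)
    #   -> PreprocessIterator (tokenization + patching)
    #   -> SequenceIterator (shuffle into fixed-length sequences).
    # build_from_rank below then mixes sources by weight with SamplingIterator,
    # packs batches with PackingIterator, and prefetches them in a background
    # process via MultiprocessIterator.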

    def build_from_rank(
        self, rank: int, world_size: int
    ) -> StatefulIterator[Batch, Any]:
        source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size)
        weight_rng_state = get_rng_state(self.seed + 1, rank, world_size)
        sampling_iterator = SamplingIterator(
            rng_state=weight_rng_state,
            source_to_weight=self.sources,
            source_to_iterator=source_to_sequence_iterators,
        )
        tokenizer = self.tokenizer_args.build()
        packing_args = PackingArgs(
            batch_size=self.batch_size,
            seq_len=self.seq_len,
            pad_id=tokenizer.boe_id,
            max_length=self.max_encoder_seq_length,
            pad_to_max_length=self.pad_to_max_length,
            enable_byte_ngrams=self.enable_byte_ngrams,
        )
        packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
        mp_iterator = MultiprocessIterator(
            packing_iterator, n_batches_to_prefetch=self.prefetch_size
        )
        return mp_iterator


class TrainArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    name: str = "lingua"
    dump_dir: str = ""

    seed: int = 42
    debug_dynamo: bool = False

    # Number of gradient accumulation steps
    # Total batch size is batch_size*grad_acc_steps
    grad_acc_steps: int = 1

    gc_collect_freq: int = 1000
    probe_freq: int | None = None

    # Number of optimizer steps to take
    steps: int = 1000

    data: DataloaderArgs = DataloaderArgs()
    optim: OptimArgs = OptimArgs()
    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
    # This is only needed for training the entropy model
    entropy_model: LMTransformerArgs | None = None
    # Instead of training the main model, train the entropy model
    train_entropy_model: bool = False
    distributed: DistributedArgs = DistributedArgs()
    env: EnvironmentArgs = EnvironmentArgs()

    checkpoint: CheckpointArgs = CheckpointArgs()
    profiling: ProfilerArgs = ProfilerArgs()
    logging: LoggingArgs = LoggingArgs()

    # If set to None, eval is run locally; otherwise it launches a new job with
    # the given number of GPUs.
    async_eval_gpus: int | None = None
    eval: Any | None = None
    eval_on_gpus: int | None = None

    def dump_to_yaml_file(
        self, path: str, log_config: bool = True, sort_keys: bool = True
    ):
        model_dict = self.model_dump(mode="json")
        yaml_str = yaml.dump(
            model_dict,
            allow_unicode=True,
            sort_keys=sort_keys,
            default_flow_style=False,
        )
        with open(path, "w") as f:
            if log_config:
                logger.info("Using the following config for this run:")
                logger.info(yaml_str)
            f.write(yaml_str)
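

# Illustrative usage sketch (not part of the module): load a TrainArgs config
# from YAML using the pydantic v2 API this file already relies on, then build
# the rank-local dataloader. The path "config.yaml" is hypothetical.
#
#     with open("config.yaml") as f:
#         train_args = TrainArgs.model_validate(yaml.safe_load(f))
#     data_iterator = train_args.data.build_from_rank(rank=0, world_size=1)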