# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
from typing import Any

import numpy as np
import yaml
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict

from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    find_and_sanitize_chunks,
)
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.iterators.sampling_iterator import SamplingIterator
from bytelatent.data.iterators.sequence_iterator import (
    SequenceIterator,
    SequencePackingArgs,
)
from bytelatent.data.patcher import PatcherArgs
from bytelatent.distributed import DistributedArgs, EnvironmentArgs
from bytelatent.metrics import LoggingArgs
from bytelatent.model.blt import ByteLatentTransformerArgs
from bytelatent.optim import OptimArgs
from bytelatent.profiling import ProfilerArgs
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformerArgs

logger = logging.getLogger()


def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
    """Return a numpy bit-generator state seeded by (seed, rank, world_size)."""
    return np.random.default_rng((seed, rank, world_size)).bit_generator.state


def parse_args(args_cls):
    """Build an args_cls instance by merging, in increasing priority: the class
    defaults, the YAML file passed as `config=...`, and the remaining dotted CLI
    overrides."""
    cli_args = OmegaConf.from_cli()
    file_cfg = OmegaConf.load(cli_args.config)
    # Remove the 'config' attribute from the CLI args, since the target args class
    # does not define it.
    del cli_args.config

    default_cfg = OmegaConf.create(args_cls().model_dump())
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    pydantic_args = args_cls.model_validate(cfg)
    return pydantic_args


def distribute_data_to_rank(
    *,
    dataset_path: str,
    preprocess_dir: str,
    entropy_model_name: str | None,
    arrow_batch_size: int,
    rank: int,
    world_size: int,
    s3_profile: str | None = None,
) -> ArrowFileIterator:
    """Assign ranks to dataset chunks (chunk-major, with world_size // num_chunks
    workers per chunk) and return the ArrowFileIterator for this rank."""
    dataset_chunks = find_and_sanitize_chunks(
        dataset_path, world_size, s3_profile=s3_profile
    )
    n_workers_per_chunk = world_size // len(dataset_chunks)
    rank_to_arrow_iterator_params = []
    for chunk_path in dataset_chunks:
        for worker_id in range(n_workers_per_chunk):
            rank_to_arrow_iterator_params.append(
                ArrowFileIterator(
                    file_path=chunk_path,
                    worker_id=worker_id,
                    num_workers=n_workers_per_chunk,
                    preprocess_dir=preprocess_dir,
                    dataset_files=None,
                    entropy_model_name=entropy_model_name,
                    arrow_batch_size=arrow_batch_size,
                    s3_profile=s3_profile,
                )
            )
    return rank_to_arrow_iterator_params[rank]
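

# Config-parsing sketch (illustrative only; the entrypoint, config path, and overrides
# below are hypothetical). parse_args merges three layers with increasing precedence:
# the pydantic defaults of args_cls, the YAML file given via `config=...`, and any
# remaining dotted CLI overrides, e.g.
#
#   python -m bytelatent.train config=configs/debug.yaml steps=100 data.batch_size=4
#
# Here `steps` and `data.batch_size` come from the CLI, fields present in
# configs/debug.yaml come from the file, and everything else keeps its default.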


class PackedCausalTransformerGeneratorArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    temperature: float = 0.0
    top_p: float | None = None
    top_k: float | None = None
    max_gen_len: int = 512  # Maximum number of tokens to generate
    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
    max_prompt_len: int | None = None
    until: list[str] = []
    compile_prefilling: bool = False
    reduce_generation_overhead: bool = False
    show_progress: bool = False
    dtype: str | None = "bf16"
    device: str | None = "cuda"


class DataloaderArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    s3_profile: str | None = None
    root_dir: str | None = None
    sources: dict[str, float] = {}
    batch_size: int = 2
    seq_len: int = 2048
    seed: int = 42
    add_bos: bool = True
    add_eos: bool = True
    load_async: bool = True
    prefetch_size: int = 64
    preprocess_dir: str | None = None
    dataset_files: list[str] | None = None
    entropy_model_name: str | None = "transformer_100m"
    arrow_batch_size: int = 100
    buffer_size: int = 64
    pad_to_max_length: bool = True
    max_encoder_seq_length: int = 12288
    enable_byte_ngrams: bool = False
    add_patches: bool = True
    tokenizer_args: TokenizerArgs = TokenizerArgs()
    patcher_args: PatcherArgs = PatcherArgs()

    def _create_sequence_iterators(
        self, rank: int, world_size: int
    ) -> dict[str, SequenceIterator]:
        """Build one tokenize/patch/pack sequence iterator per source for this rank."""
        sequence_packing_args = SequencePackingArgs(
            output_seq_len=self.seq_len,
            buffer_size=self.buffer_size,
        )
        source_to_sequence_iterator: dict[str, SequenceIterator] = {}
        for dataset_path in self.sources:
            shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
            arrow_iterator = distribute_data_to_rank(
                dataset_path=os.path.join(self.root_dir, dataset_path),
                preprocess_dir=self.preprocess_dir,
                entropy_model_name=self.entropy_model_name,
                arrow_batch_size=self.arrow_batch_size,
                rank=rank,
                world_size=world_size,
                s3_profile=self.s3_profile,
            )
            looping_iterator = LoopingIterator(arrow_iterator)
            preprocess_iterator = PreprocessIterator(
                looping_iterator,
                patcher_args=self.patcher_args,
                tokenizer_args=self.tokenizer_args,
                add_patches=self.add_patches,
            )
            sequence_iterator = SequenceIterator(
                preprocess_iterator,
                sequence_packing_args=sequence_packing_args,
                rng_state=shuffle_rng_state,
            )
            source_to_sequence_iterator[dataset_path] = sequence_iterator
        return source_to_sequence_iterator

    def build_from_rank(
        self, rank: int, world_size: int
    ) -> StatefulIterator[Batch, Any]:
        """Build this rank's batch iterator: per-source sequence iterators are mixed by
        SamplingIterator, packed into fixed-size batches by PackingIterator, and
        optionally wrapped in MultiprocessIterator when load_async is set."""
        source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size)
        weight_rng_state = get_rng_state(self.seed + 1, rank, world_size)
        sampling_iterator = SamplingIterator(
            rng_state=weight_rng_state,
            source_to_weight=self.sources,
            source_to_iterator=source_to_sequence_iterators,
        )
        tokenizer = self.tokenizer_args.build()
        if self.tokenizer_args.name == "bytes":
            # TODO: Check this with Artidoro
            pad_id = 0
        else:
            pad_id = tokenizer.boe_id
        packing_args = PackingArgs(
            batch_size=self.batch_size,
            seq_len=self.seq_len,
            pad_id=pad_id,
            max_length=self.max_encoder_seq_length,
            pad_to_max_length=self.pad_to_max_length,
            enable_byte_ngrams=self.enable_byte_ngrams,
            tokenizer_name=self.tokenizer_args.name,
        )
        packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
        if self.load_async:
            mp_iterator = MultiprocessIterator(
                packing_iterator, n_batches_to_prefetch=self.prefetch_size
            )
            return mp_iterator
        else:
            return packing_iterator
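

# Dataloading sketch (illustrative only; the paths, source name, and weight below are
# hypothetical). For each entry in `sources`, build_from_rank chains
#   ArrowFileIterator -> LoopingIterator -> PreprocessIterator -> SequenceIterator,
# mixes the sources with SamplingIterator according to their weights, packs fixed-size
# batches with PackingIterator, and wraps the result in MultiprocessIterator when
# load_async is set:
#
#   data_args = DataloaderArgs(
#       root_dir="/datasets/blt",  # hypothetical
#       preprocess_dir="/datasets/blt_preprocessed",  # hypothetical
#       sources={"my_corpus": 1.0},  # hypothetical source -> sampling weight
#   )
#   batch_iterator = data_args.build_from_rank(rank=0, world_size=1)
#   # Assuming the StatefulIterator interface exposes create_iter():
#   first_batch = next(batch_iterator.create_iter())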


class LMHarnessArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    tasks: list[Any] | None = None
    num_fewshot: int | None = None
    device: str | None = None
    use_cache: str | None = None
    cache_requests: bool = False
    rewrite_requests_cache: bool = False
    delete_requests_cache: bool = False
    limit: int | float | None = None
    bootstrap_iters: int = 100000
    check_integrity: bool = False
    write_out: bool = False
    log_samples: bool = True
    system_instruction: str | None = None
    apply_chat_template: bool | str = False
    fewshot_as_multiturn: bool = False
    gen_kwargs: str | None = None
    verbosity: str = "INFO"
    predict_only: bool = False
    random_seed: int = 0
    numpy_random_seed: int = 1234
    torch_random_seed: int = 1234
    fewshot_random_seed: int = 1234


class ValidationArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    # If None, the whole validation file is used.
    # /!\ The number of steps is GPU dependent (100 max steps on 8 GPUs = 800 steps on 1 GPU).
    max_steps: int | None = None
    use_val_from_train_src: bool = True  # Use the validation set from training sources
    root_dir: str = ""
    sources: list[str] = []  # Other sources to eval on


class EvalArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dump_dir: str
    ckpt_dir: str
    metric_log_dir: str | None = None
    generator: PackedCausalTransformerGeneratorArgs = (
        PackedCausalTransformerGeneratorArgs()
    )

    harness: LMHarnessArgs | None = LMHarnessArgs()
    validation: ValidationArgs | None = ValidationArgs()

    global_step: int | None = None  # for in-training evaluation
    s3_profile: str | None = None


class TrainArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    name: str = "lingua"
    dump_dir: str = ""

    seed: int = 42

    debug_dynamo: bool = False

    # Number of gradient accumulation steps
    # Total batch size is batch_size * grad_acc_steps
    grad_acc_steps: int = 1

    gc_collect_freq: int = 1000
    probe_freq: int | None = None

    # Number of optimizer steps to take
    steps: int = 1000
    # If not None, halt training after this many steps; useful for debugging
    max_steps: int | None = None

    data: DataloaderArgs = DataloaderArgs()
    optim: OptimArgs = OptimArgs()

    model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs()
    # This is only needed for training the entropy model
    entropy_model: LMTransformerArgs | None = None
    # Instead of training the main model, train the entropy model
    train_entropy_model: bool = False

    distributed: DistributedArgs = DistributedArgs()
    env: EnvironmentArgs = EnvironmentArgs()

    checkpoint: CheckpointArgs = CheckpointArgs()
    profiling: ProfilerArgs = ProfilerArgs()
    logging: LoggingArgs = LoggingArgs()

    # If set to None, eval is run locally; otherwise a new job is launched with the
    # given number of GPUs
    async_eval_gpus: int | None = None
    eval: EvalArgs | None = None
    eval_on_gpus: int | None = None

    def dump_to_yaml_file(
        self, path: str, log_config: bool = True, sort_keys: bool = True
    ):
        """Serialize this config to YAML at `path`, optionally logging it first."""
        model_dict = self.model_dump(mode="json")
        yaml_str = yaml.dump(
            model_dict,
            allow_unicode=True,
            sort_keys=sort_keys,
            default_flow_style=False,
        )
        with open(path, "w") as f:
            if log_config:
                logger.info("Using the following config for this run:")
                logger.info(yaml_str)
            f.write(yaml_str)
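

# Minimal manual check (illustrative sketch, not an official entrypoint): running this
# module directly dumps the default TrainArgs to YAML and re-validates the result. The
# output filename is arbitrary.
if __name__ == "__main__":
    import tempfile

    logging.basicConfig(level=logging.INFO)
    default_args = TrainArgs()
    out_path = os.path.join(tempfile.gettempdir(), "blt_default_train_args.yaml")
    default_args.dump_to_yaml_file(out_path, log_config=False)
    with open(out_path) as f:
        reloaded = TrainArgs.model_validate(yaml.safe_load(f))
    logger.info("Round-tripped default TrainArgs from %s (name=%s)", out_path, reloaded.name)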