blt/bytelatent/args.py
Pedro Rodriguez 3117ac1f1f
Some checks failed
Lint with Black / lint (push) Has been cancelled
Lint with isort / lint (push) Has been cancelled
Make it possible to specify multiple config files
Summary:

Make it possible to specify multiple config files.
Parsing CLI is not a special case anymore, just uses the same config inheritance method.

Test Plan:

Test that this interpolates in the right order via unit tests.

Sample usage: loads the internal config, which references bytelatent/configs/entropy_model.yaml. The precedence order is:

- Default pydantic args
- Included configs, e.g. `config`
- CLI args

```
python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null

```


Summary:

Test Plan:
2025-02-18 18:41:02 +00:00

325 lines
11 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
from typing import Any
import numpy as np
import yaml
from pydantic import BaseModel, ConfigDict
from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.iterators.sampling_iterator import SamplingIterator
from bytelatent.data.iterators.sequence_iterator import (
SequenceIterator,
SequencePackingArgs,
)
from bytelatent.data.patcher import PatcherArgs
from bytelatent.distributed import DistributedArgs, EnvironmentArgs
from bytelatent.metrics import LoggingArgs
from bytelatent.model.blt import ByteLatentTransformerArgs
from bytelatent.optim import OptimArgs
from bytelatent.profiling import ProfilerArgs
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformerArgs
# Module-level logger. Passing name=None explicitly resolves to the root
# logger, exactly like a bare getLogger() call.
logger = logging.getLogger(name=None)
def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
    """Return the bit-generator state of a NumPy Generator seeded for one rank.

    The tuple ``(seed, rank, world_size)`` is used as the seed entropy for
    ``np.random.default_rng``. The returned state is therefore deterministic
    in those three values and can be handed to iterators that restore it.
    """
    seed_entropy = (seed, rank, world_size)
    generator = np.random.default_rng(seed_entropy)
    return generator.bit_generator.state
# Default glob for training-data chunk files. distribute_data_to_rank uses it as
# its default file_pattern, and find_and_sanitize_chunks joins it onto the
# dataset directory before globbing.
TRAIN_DATA_FILE_PATTERN: str = "*.chunk.*.jsonl"
def find_and_sanitize_chunks(
    dataset_path: str,
    world_size: int,
    file_pattern: str,
    s3_profile: str | None = None,
):
    """Glob the chunk files of a dataset and fit their count to ``world_size``.

    Args:
        dataset_path: Directory (local or remote, resolved via ``get_fs``) that
            holds the chunk files.
        world_size: Number of ranks that will share the chunks.
        file_pattern: Glob pattern joined onto ``dataset_path``.
        s3_profile: Optional profile forwarded to ``get_fs``.

    Returns:
        The matched chunk paths in the order returned by ``fs.glob``. The list
        is truncated to the first ``world_size`` entries when there are more
        chunks than ranks.

    Raises:
        AssertionError: if no chunk matches the pattern, or if there are at
            most ``world_size`` chunks and ``world_size`` is not a multiple of
            their count.
    """
    fs = get_fs(dataset_path, s3_profile=s3_profile)
    path_with_glob = os.path.join(dataset_path, file_pattern)
    dataset_chunks = fs.glob(path_with_glob)
    n_chunks = len(dataset_chunks)
    # Check for an empty match first: the divisibility check below divides by
    # n_chunks and would otherwise fail with ZeroDivisionError instead of this
    # message.
    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
    if n_chunks > world_size:
        # More chunks than ranks: each rank reads one chunk and the surplus
        # chunks are not used. Report it rather than dropping data silently.
        n_discard = n_chunks - world_size
        logger.warning(
            "Found %d chunks in %s for world_size %d; discarding the last %d",
            n_chunks,
            dataset_path,
            world_size,
            n_discard,
        )
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"
    return dataset_chunks
def distribute_data_to_rank(
    *,
    dataset_path: str,
    preprocess_dir: str,
    entropy_model_name: str | None,
    arrow_batch_size: int,
    rank: int,
    world_size: int,
    s3_profile: str | None = None,
    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
) -> ArrowFileIterator:
    """Build the ArrowFileIterator that reads ``rank``'s shard of a dataset.

    Ranks are assigned to chunks in chunk-major order. With
    ``n = world_size // n_chunks`` workers per chunk, ranks ``0..n-1`` read
    chunk 0, ranks ``n..2n-1`` read chunk 1, and so on. A rank's ``worker_id``
    is its offset within its chunk's group.

    Args:
        dataset_path: Directory containing the chunk files.
        preprocess_dir: Forwarded to ArrowFileIterator.
        entropy_model_name: Forwarded to ArrowFileIterator.
        arrow_batch_size: Forwarded to ArrowFileIterator.
        rank: Index of the calling rank.
        world_size: Total number of ranks.
        s3_profile: Optional profile for remote filesystems.
        file_pattern: Glob used to find chunk files under ``dataset_path``.

    Returns:
        The ArrowFileIterator for this rank's (chunk, worker) pair.

    Raises:
        AssertionError: propagated from ``find_and_sanitize_chunks``.
        IndexError: if ``rank`` is outside ``[-world_size, world_size)``.
    """
    dataset_chunks = find_and_sanitize_chunks(
        dataset_path, world_size, file_pattern, s3_profile=s3_profile
    )
    n_workers_per_chunk = world_size // len(dataset_chunks)
    # Equivalent to enumerating every (chunk, worker) pair in chunk-major
    # order and taking entry `rank`, including Python's negative-index
    # behaviour. Only the one iterator this rank needs is constructed, instead
    # of one per rank.
    chunk_index, worker_id = divmod(rank, n_workers_per_chunk)
    chunk_path = dataset_chunks[chunk_index]
    return ArrowFileIterator(
        file_path=chunk_path,
        worker_id=worker_id,
        num_workers=n_workers_per_chunk,
        preprocess_dir=preprocess_dir,
        dataset_files=None,
        entropy_model_name=entropy_model_name,
        arrow_batch_size=arrow_batch_size,
        s3_profile=s3_profile,
    )
class PackedCausalTransformerGeneratorArgs(BaseModel):
    """Settings for the packed causal transformer generator.

    Used as ``EvalArgs.generator``. ``extra="forbid"`` makes unknown keys a
    validation error, so config typos are caught early.
    """

    model_config = ConfigDict(extra="forbid")
    temperature: float = 0.0
    top_p: float | None = None
    # NOTE(review): annotated as float; if the generator uses this as the
    # integer k of a top-k cutoff, `int | None` may be intended — confirm.
    top_k: float | None = None
    max_gen_len: int = 512  # Maximum number of tokens to generate
    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
    max_prompt_len: int | None = None
    # Mutable default: pydantic copies field defaults per instance — verify
    # for the pydantic version in use.
    until: list[str] = []
    compile_prefilling: bool = False
    reduce_generation_overhead: bool = False
    show_progress: bool = False
    dtype: str | None = "bf16"
    device: str | None = "cuda"
class DataloaderArgs(BaseModel):
    """Training dataloader configuration.

    ``build_from_rank`` turns these settings into the batch iterator for one
    rank. ``extra="forbid"`` makes unknown keys a validation error.
    """

    model_config = ConfigDict(extra="forbid")
    s3_profile: str | None = None
    # Optional prefix joined onto every key of `sources`. When None, the keys
    # are used as dataset paths directly.
    root_dir: str | None = None
    # Maps dataset path (relative to root_dir) -> sampling weight passed to
    # SamplingIterator.
    sources: dict[str, float] = {}
    batch_size: int = 2
    seq_len: int = 2048
    seed: int = 42
    add_bos: bool = True
    add_eos: bool = True
    load_async: bool = True
    prefetch_size: int = 64
    preprocess_dir: str | None = None
    dataset_files: list[str] | None = None
    entropy_model_name: str | None = "transformer_100m"
    arrow_batch_size: int = 100
    buffer_size: int = 64
    pad_to_max_length: bool = True
    max_encoder_seq_length: int = 12288
    enable_byte_ngrams: bool = False
    add_patches: bool = True
    tokenizer_args: TokenizerArgs = TokenizerArgs()
    patcher_args: PatcherArgs = PatcherArgs()

    def _create_sequence_iterators(
        self, rank: int, world_size: int
    ) -> dict[str, SequenceIterator]:
        """Build one SequenceIterator per entry of ``self.sources`` for ``rank``.

        Each source is read as distribute_data_to_rank -> LoopingIterator ->
        PreprocessIterator -> SequenceIterator.

        Returns:
            A mapping from each source key (as written in ``sources``) to its
            iterator.
        """
        sequence_packing_args = SequencePackingArgs(
            output_seq_len=self.seq_len,
            buffer_size=self.buffer_size,
        )
        source_to_sequence_iterator: dict[str, SequenceIterator] = {}
        for dataset_path in self.sources:
            # NOTE(review): every source receives the same state, and it equals
            # the state build_from_rank uses for source sampling (both use
            # seed + 1). It is kept unchanged for reproducibility; confirm this
            # correlation is intended. A fresh dict is created per source.
            shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
            # root_dir defaults to None. Calling os.path.join(None, ...) would
            # raise TypeError, so in that case the key is used as the path.
            full_dataset_path = (
                dataset_path
                if self.root_dir is None
                else os.path.join(self.root_dir, dataset_path)
            )
            arrow_iterator = distribute_data_to_rank(
                dataset_path=full_dataset_path,
                preprocess_dir=self.preprocess_dir,
                entropy_model_name=self.entropy_model_name,
                arrow_batch_size=self.arrow_batch_size,
                rank=rank,
                world_size=world_size,
                s3_profile=self.s3_profile,
            )
            looping_iterator = LoopingIterator(arrow_iterator)
            preprocess_iterator = PreprocessIterator(
                looping_iterator,
                patcher_args=self.patcher_args,
                tokenizer_args=self.tokenizer_args,
                add_patches=self.add_patches,
            )
            sequence_iterator = SequenceIterator(
                preprocess_iterator,
                sequence_packing_args=sequence_packing_args,
                rng_state=shuffle_rng_state,
            )
            source_to_sequence_iterator[dataset_path] = sequence_iterator
        return source_to_sequence_iterator

    def build_from_rank(
        self, rank: int, world_size: int
    ) -> StatefulIterator[Batch, Any]:
        """Build the batch iterator for ``rank`` out of ``world_size`` ranks.

        Pipeline: per-source sequence iterators -> SamplingIterator (weighted
        by ``sources``) -> PackingIterator. When ``load_async`` is True, the
        result is wrapped in a MultiprocessIterator that prefetches
        ``prefetch_size`` batches.
        """
        source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size)
        weight_rng_state = get_rng_state(self.seed + 1, rank, world_size)
        sampling_iterator = SamplingIterator(
            rng_state=weight_rng_state,
            source_to_weight=self.sources,
            source_to_iterator=source_to_sequence_iterators,
        )
        tokenizer = self.tokenizer_args.build()
        if self.tokenizer_args.name == "bytes":
            # TODO: Check this with Artidoro
            pad_id = 0
        else:
            pad_id = tokenizer.boe_id
        packing_args = PackingArgs(
            batch_size=self.batch_size,
            seq_len=self.seq_len,
            pad_id=pad_id,
            max_length=self.max_encoder_seq_length,
            pad_to_max_length=self.pad_to_max_length,
            enable_byte_ngrams=self.enable_byte_ngrams,
            tokenizer_name=self.tokenizer_args.name,
        )
        packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
        if self.load_async:
            mp_iterator = MultiprocessIterator(
                packing_iterator, n_batches_to_prefetch=self.prefetch_size
            )
            return mp_iterator
        else:
            return packing_iterator
class LMHarnessArgs(BaseModel):
    """Options for the language-model evaluation harness.

    Used as ``EvalArgs.harness``. ``extra="forbid"`` makes unknown keys a
    validation error.

    NOTE(review): the field names appear to mirror the keyword arguments of
    lm-evaluation-harness ``simple_evaluate``; confirm against the caller.
    """

    model_config = ConfigDict(extra="forbid")
    tasks: list[Any] | None = None
    num_fewshot: int | None = None
    device: str | None = None
    use_cache: str | None = None
    cache_requests: bool = False
    rewrite_requests_cache: bool = False
    delete_requests_cache: bool = False
    # Accepts an int or a float; presumably a sample count vs. a fraction of
    # each task — confirm with the harness version in use.
    limit: int | float | None = None
    bootstrap_iters: int = 100000
    check_integrity: bool = False
    write_out: bool = False
    log_samples: bool = True
    system_instruction: str | None = None
    apply_chat_template: bool | str = False
    fewshot_as_multiturn: bool = False
    gen_kwargs: str | None = None
    verbosity: str = "INFO"
    predict_only: bool = False
    # Separate seeds for Python `random`, NumPy, torch and few-shot sampling
    # (inferred from names — confirm).
    random_seed: int = 0
    numpy_random_seed: int = 1234
    torch_random_seed: int = 1234
    fewshot_random_seed: int = 1234
class ValidationArgs(BaseModel):
    """Validation-set evaluation settings, used as ``EvalArgs.validation``.

    ``extra="forbid"`` makes unknown keys a validation error.
    """

    model_config = ConfigDict(extra="forbid")
    max_steps: int | None = (
        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
    )
    use_val_from_train_src: bool = True  # Use the validation set from training sources
    root_dir: str = ""
    sources: list[str] = []  # Other sources to eval on
class EvalArgs(BaseModel):
    """Evaluation configuration, used as ``TrainArgs.eval``.

    ``dump_dir`` and ``ckpt_dir`` have no defaults and must always be given.
    ``harness`` and ``validation`` may be set to None (presumably to skip that
    evaluation — confirm with the caller). ``extra="forbid"`` makes unknown
    keys a validation error.
    """

    model_config = ConfigDict(extra="forbid")
    dump_dir: str
    ckpt_dir: str
    metric_log_dir: str | None = None
    generator: PackedCausalTransformerGeneratorArgs = (
        PackedCausalTransformerGeneratorArgs()
    )
    harness: LMHarnessArgs | None = LMHarnessArgs()
    validation: ValidationArgs | None = ValidationArgs()
    global_step: int | None = None  # for in-training evaluation
    s3_profile: str | None = None
class TrainArgs(BaseModel):
    """Top-level training configuration.

    It aggregates the data, optimizer, model, distributed, checkpoint,
    profiling, logging and evaluation settings, and can serialize itself to
    YAML. ``extra="forbid"`` makes unknown keys a validation error.
    """

    model_config = ConfigDict(extra="forbid")
    name: str = "lingua"
    dump_dir: str = ""
    seed: int = 42
    debug_dynamo: bool = False
    # Number of gradient accumulation steps
    # Total batch size is batch_size*grad_acc_steps
    grad_acc_steps: int = 1
    gc_collect_freq: int = 1000
    probe_freq: int | None = None
    # Nb optimizer steps to take
    steps: int = 1000
    # If not None, halt training after this many steps,
    # useful for debugging
    max_steps: int | None = None
    data: DataloaderArgs = DataloaderArgs()
    optim: OptimArgs = OptimArgs()
    model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs()
    # This is only needed for training the entropy model
    entropy_model: LMTransformerArgs | None = None
    # Instead of training main model, train entropy model
    train_entropy_model: bool = False
    distributed: DistributedArgs = DistributedArgs()
    env: EnvironmentArgs = EnvironmentArgs()
    checkpoint: CheckpointArgs = CheckpointArgs()
    profiling: ProfilerArgs = ProfilerArgs()
    logging: LoggingArgs = LoggingArgs()
    # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
    async_eval_gpus: int | None = None
    eval: EvalArgs | None = None
    eval_on_gpus: int | None = None

    def dump_to_yaml_file(
        self, path: str, log_config: bool = True, sort_keys: bool = True
    ):
        """Serialize this config to YAML and write it to ``path``.

        Args:
            path: Destination file; it is created or truncated.
            log_config: If True, also log the YAML text at INFO level.
            sort_keys: Forwarded to ``dump_to_yaml_str``.
        """
        yaml_str = self.dump_to_yaml_str(sort_keys=sort_keys)
        # dump_to_yaml_str uses allow_unicode=True, so the text may contain
        # non-ASCII characters. Pin UTF-8 rather than relying on the locale's
        # default encoding.
        with open(path, "w", encoding="utf-8") as f:
            if log_config:
                logger.info("Using the following config for this run:")
                logger.info(yaml_str)
            f.write(yaml_str)

    def dump_to_yaml_str(self, sort_keys: bool = True):
        """Return this config as a block-style YAML string.

        The model is first dumped in pydantic's JSON mode so every value is a
        plain YAML-serializable type. ``sort_keys`` controls whether mapping
        keys are sorted.
        """
        model_dict = self.model_dump(mode="json")
        yaml_str = yaml.dump(
            model_dict,
            allow_unicode=True,
            sort_keys=sort_keys,
            default_flow_style=False,
        )
        return yaml_str