From 7044771a120b2cae7f6bb1315330a9f2491246ee Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 27 Jan 2025 16:56:42 -0800 Subject: [PATCH] This includes fixes that make checkpointing and reloading work correctly. (#35) It also batches in a first set of changes for fixing eval code Summary: Test Plan: --- apps/main/lingua_train.py | 2 +- bytelatent/args.py | 87 ++++++++- bytelatent/checkpoint.py | 9 +- bytelatent/configs/debug.yaml | 9 +- bytelatent/configs/entropy_model.yaml | 9 +- bytelatent/data/data_types.py | 10 - .../data/iterators/multiprocess_iterator.py | 15 ++ {apps/main => bytelatent}/eval.py | 179 ++++-------------- {apps/main => bytelatent}/generate.py | 57 +++--- bytelatent/train.py | 81 ++++---- 10 files changed, 221 insertions(+), 237 deletions(-) rename {apps/main => bytelatent}/eval.py (56%) rename {apps/main => bytelatent}/generate.py (91%) diff --git a/apps/main/lingua_train.py b/apps/main/lingua_train.py index bdb47da..7925ec6 100644 --- a/apps/main/lingua_train.py +++ b/apps/main/lingua_train.py @@ -544,7 +544,7 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval + from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval eval_args = dataclass_from_dict(EvalArgs, args.eval) diff --git a/bytelatent/args.py b/bytelatent/args.py index 56de22d..d1bac46 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,6 +5,7 @@ from typing import Any import numpy as np import yaml +from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state +def parse_args(args_cls): + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.create(args_cls().model_dump()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + pydantic_args = args_cls.model_validate(cfg) + return pydantic_args + + def distribute_data_to_rank( *, dataset_path: str, @@ -71,6 +85,22 @@ def distribute_data_to_rank( return rank_to_arrow_iterator_params[rank] +class PackedCausalTransformerGeneratorArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + temperature: float = 0.0 + top_p: float | None = None + top_k: float | None = None + max_gen_len: int = 512 # Maximum number of tokens to generate + max_tokens: int = 1024 # Maximum number of tokens that can go through the model + max_prompt_len: int | None = None + until: list[str] = [] + compile_prefilling: bool = False + reduce_generation_overhead: bool = False + show_progress: bool = False + dtype: str | None = "bf16" + device: str | None = "cuda" + + class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") s3_profile: str | None = None @@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel): return packing_iterator +class LMHarnessArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + tasks: list[Any] | None = None + num_fewshot: int | None = None + device: str | None = None + use_cache: str | None = None + cache_requests: bool = False + rewrite_requests_cache: bool = False + delete_requests_cache: bool = 
False + limit: int | float | None = None + bootstrap_iters: int = 100000 + check_integrity: bool = False + write_out: bool = False + log_samples: bool = True + system_instruction: str | None = None + apply_chat_template: bool | str = False + fewshot_as_multiturn: bool = False + gen_kwargs: str | None = None + verbosity: str = "INFO" + predict_only: bool = False + random_seed: int = 0 + numpy_random_seed: int = 1234 + torch_random_seed: int = 1234 + fewshot_random_seed: int = 1234 + + +class ValidationArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + max_steps: int | None = ( + None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) + ) + use_val_from_train_src: bool = True # Use the validation set from training sources + root_dir: str = "" + sources: list[str] = [] # Other sources to eval on + + +class EvalArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dump_dir: str + ckpt_dir: str + metric_log_dir: str | None = None + generator: PackedCausalTransformerGeneratorArgs = ( + PackedCausalTransformerGeneratorArgs() + ) + + harness: LMHarnessArgs | None = LMHarnessArgs() + validation: ValidationArgs | None = ValidationArgs() + + global_step: int | None = None # for in-training evaluation + s3_profile: str | None = None + + class TrainArgs(BaseModel): model_config = ConfigDict(extra="forbid") name: str = "lingua" @@ -186,6 +268,9 @@ class TrainArgs(BaseModel): # Nb optimizer steps to take steps: int = 1000 + # If not None, halt training after this many steps, + # useful for debugging + max_steps: int | None = None data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() @@ -203,7 +288,7 @@ class TrainArgs(BaseModel): # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus async_eval_gpus: int | None = None - eval: Any | None = None + eval: EvalArgs | None = None eval_on_gpus: int | None = None def dump_to_yaml_file( diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index bcf591e..f213c84 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -7,6 +7,7 @@ import re from pathlib import Path from typing import List, Optional, Tuple +import fsspec import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp @@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import ( set_state_dict, ) +from bytelatent.data.file_util import get_fs from bytelatent.distributed import get_is_master logger = logging.getLogger("CHECKPOINT") @@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel): path: str | None = None init_ckpt_path: str | None = None continue_training_from_init: bool = False + s3_profile: str | None = None def _get_key_step(name: str): return int(re.findall(RE_DIGITS, name)[-1]) -def consolidate_checkpoints(ckpt_dir: str): +def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): """ Consolidates all FSDP checkpoints in a directory to a single file Consolidate checkpoint is saved in a subdirectory of ckpt_dir @@ -102,15 +105,17 @@ def load_from_checkpoint( dcp.load(state_dict, checkpoint_id=ckpt_dir) +# TODO: Rewrite the file operations here to use fsspec to enable s3 writing. 
class CheckpointManager: def __init__(self, args: CheckpointArgs): self.path = args.path + self.fs = get_fs(self.path, s3_profile=args.s3_profile) self.dump_every = args.dump self.eval_every = args.eval self.init_ckpt_path = args.init_ckpt_path self.continue_training_from_init = args.continue_training_from_init - assert os.path.exists( + assert self.fs.exists( self.path ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 1098ff5..07d489f 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -98,11 +98,4 @@ logging: freq: 10 eval_on_gpus: 8 -eval: - dataset_dir: /checkpoint/amaia/codegen/datasets/eval - tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu - generator: - max_tokens: 65536 - dtype: bf16 - - mp_size: 1 +eval: null diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml index 51b65d4..d7c27b7 100644 --- a/bytelatent/configs/entropy_model.yaml +++ b/bytelatent/configs/entropy_model.yaml @@ -72,11 +72,4 @@ logging: freq: 10 eval_on_gpus: 8 -eval: - dataset_dir: ??? - tasks: ??? - generator: - max_tokens: 65536 - dtype: bf16 - - mp_size: 1 +eval: null diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index aa2daa9..f4bbc07 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel): n_views: int = 2 -class DataLoaderState(BaseModel): - model_config = ConfigDict(extra="forbid") - multi_choice_state: MultiChoiceState - pack_tokens_state: BltPackTokensState - prefetch_state: PrefetchState - - -BltIterator = Iterator[tuple[BltExample, DataLoaderState]] - - class BltSequence(BaseModel): tokens: list[int] mask: list[bool] diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index f17ca6e..49d99ac 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator): self.producer = None self.stop_iterating_event = None self.state_dumped_event = None + self.force_shutdown = False + + def shutdown(self): + if self.producer is not None: + # This properly shuts things down + self.producer.kill() + self.force_shutdown = True def get_state(self) -> MultiprocessIteratorState: """ @@ -135,6 +142,10 @@ class MultiprocessIterator(StatefulIterator): to halt the background process and allow it to write the state to the main loop in order to not lose data """ + if self.force_shutdown: + raise ValueError( + "State will be invalid if shutdown was forced before state persisted." + ) if self.producer is None: serialized_prefetch_buffer = json.dumps( [b.to_python_dict() for b in self.prefetch_buffer] @@ -187,6 +198,10 @@ class MultiprocessIterator(StatefulIterator): ) def create_iter(self): + if self.force_shutdown: + raise ValueError( + "Iterator may be invalid if shutdown was forced before state persisted." + ) logging.info("Main thread: Creating MP iterator") # First yield from the stored prefetch buffer. 
if self.prefetch_buffer is not None: diff --git a/apps/main/eval.py b/bytelatent/eval.py similarity index 56% rename from apps/main/eval.py rename to bytelatent/eval.py index ed20f49..ae73066 100644 --- a/apps/main/eval.py +++ b/bytelatent/eval.py @@ -4,20 +4,20 @@ import json import logging import os from collections import defaultdict -from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any import torch -from lingua.args import dump_config -from lingua.data import init_choice_state, setup_sources from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM from omegaconf import OmegaConf +from pydantic import BaseModel, ConfigDict +from bytelatent.args import EvalArgs, ValidationArgs, parse_args from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -25,72 +25,17 @@ from bytelatent.distributed import ( get_world_size, setup_torch_distributed, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs - -from apps.main.generate import ( +from bytelatent.generate import ( PackedCausalTransformerGenerator, - PackedCausalTransformerGeneratorArgs, load_consolidated_model_and_tokenizer, ) +from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" logger = logging.getLogger() -@dataclass -class LMHarnessArgs: - tasks: Optional[List[Any]] = None - num_fewshot: Optional[int] = None - device: Optional[str] = None - use_cache: Optional[str] = None - cache_requests: bool = False - rewrite_requests_cache: bool = False - delete_requests_cache: bool = False - limit: Optional[Union[int, float]] = None - bootstrap_iters: int = 100000 - check_integrity: bool = False - write_out: bool = False - log_samples: bool = True - system_instruction: Optional[str] = None - apply_chat_template: Union[bool, str] = False - fewshot_as_multiturn: bool = False - gen_kwargs: Optional[str] = None - verbosity: str = "INFO" - predict_only: bool = False - random_seed: int = 0 - numpy_random_seed: int = 1234 - torch_random_seed: int = 1234 - fewshot_random_seed: int = 1234 - - -@dataclass -class ValidationArgs: - max_steps: Optional[int] = ( - None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) - ) - use_val_from_train_src: bool = True # Use the validation set from training sources - root_dir: str = "" - sources: List[str] = field(default_factory=list) # Other sources to eval on - - -@dataclass -class EvalArgs: - name: str = "evals" - dump_dir: Optional[str] = None - metric_log_dir: Optional[str] = None - ckpt_dir: str = "" - generator: PackedCausalTransformerGeneratorArgs = field( - default_factory=PackedCausalTransformerGeneratorArgs - ) - harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs) - validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs) - - wandb: Optional[Any] = None - - global_step: Optional[int] = None # for in-training evaluation - - def all_dicts_same(dict_list): if not dict_list: # Check if the list is empty return True @@ -120,7 +65,7 @@ class EvalHarnessLM(LM): self._world_size = get_world_size() self.device = generator.device - def generate_until(self, requests: List[Instance]) -> List[str]: + def 
generate_until(self, requests: list[Instance]) -> list[str]: prompts, gen_args = zip(*[req.args for req in requests]) assert all_dicts_same(gen_args), "Doesn't support different gen args for now" gen_args = gen_args[0] @@ -141,7 +86,7 @@ class EvalHarnessLM(LM): filtered_gen.append(g) return filtered_gen - def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: prompts, continuations = zip(*[req.args for req in requests]) inputs = [req.args[0] + req.args[1] for req in requests] max_gen_len = self.generator.max_gen_len @@ -158,7 +103,7 @@ class EvalHarnessLM(LM): self.generator.max_gen_len = max_gen_len return results - def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: prompts = [req.args[0] for req in requests] max_gen_len = self.generator.max_gen_len # We temporarily lower max gen len @@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): return all_val_metrics -def launch_eval(cfg: EvalArgs): +def launch_eval(eval_args: EvalArgs): if not torch.distributed.is_initialized(): setup_torch_distributed(DistributedArgs()) + + fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile) if ( - Path(cfg.ckpt_dir).exists() - and (Path(cfg.ckpt_dir) / "params.json").exists() - and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None + fs.exists(eval_args.ckpt_dir) + and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json")) + and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0 ): - consolidate_path = Path(cfg.ckpt_dir) + consolidate_path = eval_args.ckpt_dir else: - consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER - if not consolidate_path.exists() and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(cfg.ckpt_dir) + consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) + if not fs.exists(consolidate_path) and get_global_rank() == 0: + consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) - Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True) - dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False) + fs.mkdirs(eval_args.dump_dir, exist_ok=True) + with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: + f.write(eval_args.model_dump_json()) - consolidate_path = str(consolidate_path) torch.distributed.barrier() logger.info("Loading model") + # TODO: Make this general so that it works with either + # LMTransformer or Blt, similar with args model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( consolidate_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ) logger.info("Model loaded") model.eval() - generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer) + generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer) wrap = EvalHarnessLM(generator) - results = simple_evaluate(wrap, **asdict(cfg.harness)) + # Redo + results = simple_evaluate(wrap, eval_args.harness.model_dump()) val_results = None - if cfg.validation: - val_results = eval_on_val(generator, cfg.validation, train_cfg) + if eval_args.validation: + val_results = eval_on_val(generator, eval_args.validation, train_cfg) if get_global_rank() == 0: - with open(Path(cfg.dump_dir) / "results.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) 
logger.info(f"All evaluation results: {results['results']}") if val_results is not None: - with open(Path(cfg.dump_dir) / "validation.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") - if cfg.metric_log_dir and get_global_rank() == 0: - metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl" + if eval_args.metric_log_dir and get_global_rank() == 0: + metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") logger.info(f"Writing metric logs to {metric_log_path}") timestamp = { "created_at": datetime.utcnow().isoformat(), } - if cfg.global_step is not None: - timestamp["global_step"] = cfg.global_step + if eval_args.global_step is not None: + timestamp["global_step"] = eval_args.global_step print( json.dumps(timestamp | results["results"]), - file=open(metric_log_path, mode="a"), + file=fs.open(metric_log_path, mode="a"), flush=True, ) - val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl" + val_log_path = os.path.join( + eval_args.metric_log_dir, "metrics.validation.jsonl" + ) if val_results is not None: print( json.dumps(timestamp | val_results), - file=open(val_log_path, mode="a"), + file=fs.open(val_log_path, mode="a"), flush=True, ) @@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs): def main(): - """ - The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments - This accepts arguments as a dot list - So if the dataclass looks like - - @dataclass - class DummyArgs: - name: str - model: LMTransformerArgsgs - - @dataclass - class LMTransformerArgsgs: - dim: int - - Then you can pass model.dim=32 to change values in LMTransformerArgsgs - or just name=tictac for top level attributes. - - The behavior here is as follows: - 1. We instantiate EvalArgs with its default values - 2. We override those default values with the ones in the provided config file - 3. We override the result with the additional arguments provided through command line - - For example, if the config is the following - - model: - dim: 128 - n_layers: 4 - - and you call eval.py with eval.py model.dim=64 - - Then the final TrainArgs will have - - model: - dim: 64 - n_layers: 4 - - Plus all the default values in EvalArgs dataclass. - """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.structured(EvalArgs()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_object(cfg) - launch_eval(cfg) + eval_args = parse_args(EvalArgs) + launch_eval(eval_args) if __name__ == "__main__": diff --git a/apps/main/generate.py b/bytelatent/generate.py similarity index 91% rename from apps/main/generate.py rename to bytelatent/generate.py index a3a8627..eb79d81 100644 --- a/apps/main/generate.py +++ b/bytelatent/generate.py @@ -1,20 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import List, Optional import torch -from lingua.args import dataclass_from_dict -from lingua.tokenizers.abstract_tokenizer import Tokenizer -from lingua.tokenizers.build_tokenizer import build_tokenizer from omegaconf import OmegaConf from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import create_block_mask from tqdm import tqdm +from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs from bytelatent.base_transformer import ( Attention, causal_mask, @@ -23,7 +19,10 @@ from bytelatent.base_transformer import ( lengths_to_start_ids, ) from bytelatent.checkpoint import CONSOLIDATE_NAME -from bytelatent.transformer import LMTransformer, LMTransformerArgs +from bytelatent.data.file_util import get_fs +from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer +from bytelatent.transformer import LMTransformer def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: @@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None): return next_token.view(shape[:-1]) -def pack_prompts(prompts: List[int]): +def pack_prompts(prompts: list[int]): res = [] lengths = [] for i, p in enumerate(prompts): @@ -120,22 +119,6 @@ class KVCache(nn.Module): return self.k_cache, self.v_cache -@dataclass -class PackedCausalTransformerGeneratorArgs: - temperature: float = 0.0 - top_p: Optional[float] = None - top_k: Optional[float] = None - max_gen_len: int = 512 # Maximum number of tokens to generate - max_tokens: int = 1024 # Maximum number of tokens that can go through the model - max_prompt_len: Optional[int] = None - until: List[str] = field(default_factory=list) - compile_prefilling: bool = False - reduce_generation_overhead: bool = False - show_progress: bool = False - dtype: Optional[str] = "bf16" - device: Optional[str] = "cuda" - - class PackedCausalTransformerGenerator: def __init__( self, @@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator: def load_consolidated_model_and_tokenizer( consolidated_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ): - ckpt_path = Path(consolidated_path) - config = ckpt_path / "params.json" - config = OmegaConf.load(config) + train_args_path = os.path.join(consolidated_path, "params.json") + fs = get_fs(train_args_path) + with fs.open(train_args_path) as f: + train_args = TrainArgs.model_validate_json(f.read()) + + if train_args.train_entropy_model: + model_args = train_args.entropy_model + model = LMTransformer(model_args) + else: + model_args = train_args.model + model = ByteLatentTransformer(model_args) param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ - config.distributed.model_dtype + train_args.distributed.model_dtype ] - model_args = dataclass_from_dict(model_args_cls, config.model, strict=False) - tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path) - model = model_cls(model_args) - st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True) + tokenizer = train_args.data.tokenizer_args.build() + st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): param.data = param.data.to(dtype=param_dtype) - return model, tokenizer, config + return model, tokenizer, train_args def main(): diff --git 
a/bytelatent/train.py b/bytelatent/train.py index 1d0fa40..6b20ecd 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -10,7 +10,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pathlib import Path from timeit import default_timer as timer -from typing import Any, Dict, Type, TypeVar +from typing import Any, TypeVar import torch import torch.distributed @@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs +from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint -from bytelatent.data.data_types import DataLoaderState +from bytelatent.data.iterators.multiprocess_iterator import ( + MultiprocessIterator, + MultiprocessIteratorState, +) +from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, @@ -39,6 +43,7 @@ from bytelatent.distributed import ( setup_env, setup_torch_distributed, ) +from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer @@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"): return dict(items) -def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T: - """ - Converts a dictionary to a dataclass instance, recursively for nested structures. - """ - base = OmegaConf.structured(cls()) - OmegaConf.set_struct(base, strict) - override = OmegaConf.create(data) - return OmegaConf.to_object(OmegaConf.merge(base, override)) +def get_iterator_state_name(iterator_state): + if isinstance(iterator_state, MultiprocessIteratorState): + return "multiprocess" + elif isinstance(iterator_state, PackingIteratorState): + return "packing" + else: + raise ValueError(f"Unsupported iterator to get name from: {iterator_state}") +# TODO: Make this pydantic based instead of data class based +# TODO: Generalize this to any iterator state @dataclass class TrainState(Stateful): step: int # Nb of steps taken by the optimizer acc_step: int # Nb of accumulation steps done since last optimizer step scheduler: lr_scheduler.LambdaLR - data_loader_state: DataLoaderState + data_loader_state: MultiprocessIteratorState | PackingIteratorState scale: float = 1.0 + data_loader_class: str | None = None - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "step": self.step, "acc_step": self.acc_step, - "data_loader_state": self.data_loader_state.dict(), + "data_loader_state": self.data_loader_state.model_dump(), + "data_loader_class": get_iterator_state_name(self.data_loader_state), "scheduler": self.scheduler.state_dict(), } def load_state_dict(self, state_dict): self.step = state_dict["step"] self.acc_step = state_dict["acc_step"] - self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"]) + self.data_loader_class = state_dict["data_loader_class"] + if self.data_loader_class == "multiprocess": + self.data_loader_state = MultiprocessIteratorState( + **state_dict["data_loader_state"] + ) + elif self.data_loader_class == "packing": + self.data_loader_state = PackingIteratorState( + **state_dict["data_loader_state"] + ) + else: + raise ValueError(f"invalid data loader class: {self.data_loader_class}") 
self.scheduler.load_state_dict(state_dict["scheduler"]) @@ -345,7 +363,10 @@ def train(args: TrainArgs): nwords_since_last_log = 0 time_last_log = timer() gc.collect() - while train_state.step < args.steps: + saved = False + while train_state.step < args.steps and ( + args.max_steps is None or train_state.step < args.max_steps + ): # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 train_state.acc_step += 1 train_state.acc_step = train_state.acc_step % args.grad_acc_steps @@ -552,7 +573,6 @@ def train(args: TrainArgs): f" pow: {gpu_mem_stats.power_draw/1000} W" ) - saved = False if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): @@ -567,18 +587,14 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval - - eval_args = dataclass_from_dict(EvalArgs, args.eval) + eval_args = args.eval eval_args.global_step = train_state.step eval_args.ckpt_dir = str(checkpoint.existing_saves[-1]) - eval_args.dump_dir = str( - os.path.join( - args.dump_dir, - "evals", - EVAL_FOLDER_NAME.format(train_state.step), - ) + eval_args.dump_dir = os.path.join( + args.dump_dir, + "evals", + EVAL_FOLDER_NAME.format(train_state.step), ) eval_args.metric_log_dir = args.dump_dir if args.async_eval_gpus is None: @@ -619,6 +635,9 @@ def train(args: TrainArgs): args, device_mesh=world_mesh, ) + if isinstance(data_loader, MultiprocessIterator): + logger.info("Closing MP iterator before exiting") + data_loader.shutdown() gc.collect() @@ -661,15 +680,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(TrainArgs().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - train_args = TrainArgs.model_validate(cfg) + train_args = parse_args(TrainArgs) if train_args.debug_dynamo: import torch._dynamo
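
A minimal usage sketch of the parse_args helper added in bytelatent/args.py, which centralizes the OmegaConf merge that train.py and eval.py previously duplicated. The command line and printed fields below are placeholders, but the merge order (Pydantic defaults, then the YAML file named by config=, then CLI dot-list overrides, re-validated through the Pydantic model) follows the function as it appears in this patch.

# Illustrative usage of parse_args; the YAML path and override values are
# placeholders, and the module invocation is only an example.
#
#   python -m bytelatent.train config=bytelatent/configs/debug.yaml steps=100
#
# parse_args merges three layers before validation:
#   1. defaults: OmegaConf.create(TrainArgs().model_dump())
#   2. file:     OmegaConf.load(cli_args.config)
#   3. CLI:      OmegaConf.from_cli(), with the 'config' key removed
from bytelatent.args import TrainArgs, parse_args

def main() -> None:
    train_args = parse_args(TrainArgs)  # returns a validated TrainArgs instance
    print(train_args.steps, train_args.dump_dir)

if __name__ == "__main__":
    main()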
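
The checkpoint-resume fix hinges on TrainState.state_dict() storing the iterator state as a Pydantic model_dump() alongside a data_loader_class tag, so that load_state_dict() can rebuild the matching state class on reload. Below is a minimal sketch of that round trip; DummyIteratorState stands in for MultiprocessIteratorState / PackingIteratorState and its fields are invented for illustration.

# Sketch of the tag-and-rebuild pattern used by TrainState in this patch.
# DummyIteratorState and its fields are hypothetical stand-ins.
from pydantic import BaseModel

class DummyIteratorState(BaseModel):
    sequence_index: int = 0
    rng_seed: int = 42

STATE_CLASSES = {"dummy": DummyIteratorState}

def save_loader_state(state: DummyIteratorState) -> dict:
    # Mirrors TrainState.state_dict(): dump the pydantic state and tag it
    # with a name so the loader knows which class to reconstruct.
    return {
        "data_loader_state": state.model_dump(),
        "data_loader_class": "dummy",
    }

def load_loader_state(payload: dict) -> DummyIteratorState:
    # Mirrors TrainState.load_state_dict(): pick the class from the tag,
    # then rebuild the pydantic object from the dumped dict.
    cls = STATE_CLASSES[payload["data_loader_class"]]
    return cls(**payload["data_loader_state"])

if __name__ == "__main__":
    restored = load_loader_state(save_loader_state(DummyIteratorState(sequence_index=7)))
    assert restored.sequence_index == 7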