This includes fixes that make checkpointing and reloading work correctly. (#35)

It also includes a first batch of changes toward fixing the eval code.
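For context, eval configuration now flows through the same pydantic-based parse_args helper as training (added to the shared bytelatent.args module in this diff), so the eval entrypoint reduces to roughly the sketch below. This is illustrative only; the config path and the override are placeholders.

    # Mirrors the new main() in bytelatent/eval.py: a YAML config plus CLI dotlist
    # overrides are merged into a validated EvalArgs (extra keys are rejected).
    #   e.g.  python -m bytelatent.eval config=path/to/eval.yaml generator.max_gen_len=256
    from bytelatent.args import EvalArgs, parse_args
    from bytelatent.eval import launch_eval

    eval_args = parse_args(EvalArgs)  # reads the config=... key from the command line
    launch_eval(eval_args)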

Summary:

Test Plan:
Pedro Rodriguez 2025-01-27 16:56:42 -08:00 committed by GitHub
parent 7622d28b74
commit 7044771a12
10 changed files with 221 additions and 237 deletions


@@ -544,7 +544,7 @@ def train(args: TrainArgs):
if args.eval is not None and every_n_steps(
train_state, args.checkpoint.eval.every, acc_step=0
):
from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
eval_args = dataclass_from_dict(EvalArgs, args.eval)


@@ -5,6 +5,7 @@ from typing import Any
import numpy as np
import yaml
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict
from bytelatent.checkpoint import CheckpointArgs
@@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
return np.random.default_rng((seed, rank, world_size)).bit_generator.state
def parse_args(args_cls):
cli_args = OmegaConf.from_cli()
file_cfg = OmegaConf.load(cli_args.config)
# We remove 'config' attribute from config as the underlying DataClass does not have it
del cli_args.config
default_cfg = OmegaConf.create(args_cls().model_dump())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
pydantic_args = args_cls.model_validate(cfg)
return pydantic_args
def distribute_data_to_rank(
*,
dataset_path: str,
@@ -71,6 +85,22 @@ def distribute_data_to_rank
return rank_to_arrow_iterator_params[rank]
class PackedCausalTransformerGeneratorArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
temperature: float = 0.0
top_p: float | None = None
top_k: float | None = None
max_gen_len: int = 512 # Maximum number of tokens to generate
max_tokens: int = 1024 # Maximum number of tokens that can go through the model
max_prompt_len: int | None = None
until: list[str] = []
compile_prefilling: bool = False
reduce_generation_overhead: bool = False
show_progress: bool = False
dtype: str | None = "bf16"
device: str | None = "cuda"
class DataloaderArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
s3_profile: str | None = None
@@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel):
return packing_iterator
class LMHarnessArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
tasks: list[Any] | None = None
num_fewshot: int | None = None
device: str | None = None
use_cache: str | None = None
cache_requests: bool = False
rewrite_requests_cache: bool = False
delete_requests_cache: bool = False
limit: int | float | None = None
bootstrap_iters: int = 100000
check_integrity: bool = False
write_out: bool = False
log_samples: bool = True
system_instruction: str | None = None
apply_chat_template: bool | str = False
fewshot_as_multiturn: bool = False
gen_kwargs: str | None = None
verbosity: str = "INFO"
predict_only: bool = False
random_seed: int = 0
numpy_random_seed: int = 1234
torch_random_seed: int = 1234
fewshot_random_seed: int = 1234
class ValidationArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
max_steps: int | None = (
None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
)
use_val_from_train_src: bool = True # Use the validation set from training sources
root_dir: str = ""
sources: list[str] = [] # Other sources to eval on
class EvalArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
dump_dir: str
ckpt_dir: str
metric_log_dir: str | None = None
generator: PackedCausalTransformerGeneratorArgs = (
PackedCausalTransformerGeneratorArgs()
)
harness: LMHarnessArgs | None = LMHarnessArgs()
validation: ValidationArgs | None = ValidationArgs()
global_step: int | None = None # for in-training evaluation
s3_profile: str | None = None
class TrainArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
name: str = "lingua"
@@ -186,6 +268,9 @@ class TrainArgs(BaseModel):
# Nb optimizer steps to take
steps: int = 1000
# If not None, halt training after this many steps,
# useful for debugging
max_steps: int | None = None
data: DataloaderArgs = DataloaderArgs()
optim: OptimArgs = OptimArgs()
@@ -203,7 +288,7 @@ class TrainArgs(BaseModel):
# If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
async_eval_gpus: int | None = None
eval: Any | None = None
eval: EvalArgs | None = None
eval_on_gpus: int | None = None
def dump_to_yaml_file(
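Aside (not part of the diff, just a sketch of the schema above): only dump_dir and ckpt_dir lack defaults, and extra="forbid" means unknown keys fail validation, so a minimal EvalArgs can be built directly:

    from bytelatent.args import EvalArgs

    # Placeholder paths; generator, harness and validation fall back to the
    # defaults declared in the classes above.
    eval_args = EvalArgs(
        dump_dir="/tmp/evals/0000001000",
        ckpt_dir="/checkpoints/0000001000/consolidated",
    )
    print(eval_args.model_dump_json())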


@@ -7,6 +7,7 @@ import re
from pathlib import Path
from typing import List, Optional, Tuple
import fsspec
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
@@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import (
set_state_dict,
)
from bytelatent.data.file_util import get_fs
from bytelatent.distributed import get_is_master
logger = logging.getLogger("CHECKPOINT")
@@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel):
path: str | None = None
init_ckpt_path: str | None = None
continue_training_from_init: bool = False
s3_profile: str | None = None
def _get_key_step(name: str):
return int(re.findall(RE_DIGITS, name)[-1])
def consolidate_checkpoints(ckpt_dir: str):
def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
"""
Consolidates all FSDP checkpoints in a directory to a single file
Consolidate checkpoint is saved in a subdirectory of ckpt_dir
@@ -102,15 +105,17 @@ def load_from_checkpoint(
dcp.load(state_dict, checkpoint_id=ckpt_dir)
# TODO: Rewrite the file operations here to use fsspec to enable s3 writing.
class CheckpointManager:
def __init__(self, args: CheckpointArgs):
self.path = args.path
self.fs = get_fs(self.path, s3_profile=args.s3_profile)
self.dump_every = args.dump
self.eval_every = args.eval
self.init_ckpt_path = args.init_ckpt_path
self.continue_training_from_init = args.continue_training_from_init
assert os.path.exists(
assert self.fs.exists(
self.path
), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"


@@ -98,11 +98,4 @@ logging:
freq: 10
eval_on_gpus: 8
eval:
dataset_dir: /checkpoint/amaia/codegen/datasets/eval
tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
generator:
max_tokens: 65536
dtype: bf16
mp_size: 1
eval: null


@@ -72,11 +72,4 @@ logging:
freq: 10
eval_on_gpus: 8
eval:
dataset_dir: ???
tasks: ???
generator:
max_tokens: 65536
dtype: bf16
mp_size: 1
eval: null


@@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel):
n_views: int = 2
class DataLoaderState(BaseModel):
model_config = ConfigDict(extra="forbid")
multi_choice_state: MultiChoiceState
pack_tokens_state: BltPackTokensState
prefetch_state: PrefetchState
BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
class BltSequence(BaseModel):
tokens: list[int]
mask: list[bool]


@@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator):
self.producer = None
self.stop_iterating_event = None
self.state_dumped_event = None
self.force_shutdown = False
def shutdown(self):
if self.producer is not None:
# This properly shuts things down
self.producer.kill()
self.force_shutdown = True
def get_state(self) -> MultiprocessIteratorState:
"""
@@ -135,6 +142,10 @@
to halt the background process and allow it to write the state to the main loop
in order to not lose data
"""
if self.force_shutdown:
raise ValueError(
"State will be invalid if shutdown was forced before state persisted."
)
if self.producer is None:
serialized_prefetch_buffer = json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
@@ -187,6 +198,10 @@
)
def create_iter(self):
if self.force_shutdown:
raise ValueError(
"Iterator may be invalid if shutdown was forced before state persisted."
)
logging.info("Main thread: Creating MP iterator")
# First yield from the stored prefetch buffer.
if self.prefetch_buffer is not None:
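Aside (sketch of the intended usage, mirroring the cleanup added at the end of train() below): shutdown() force-kills the background producer, after which get_state() and create_iter() deliberately raise, since the prefetch state may never have been persisted.

    from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator

    def close_dataloader(data_loader) -> None:
        # Only MultiprocessIterator needs an explicit shutdown; other iterator types
        # used by the dataloader have no background process to stop.
        if isinstance(data_loader, MultiprocessIterator):
            data_loader.shutdown()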


@@ -4,20 +4,20 @@ import json
import logging
import os
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union
from typing import Any
import torch
from lingua.args import dump_config
from lingua.data import init_choice_state, setup_sources
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict
from bytelatent.args import EvalArgs, ValidationArgs, parse_args
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.data.file_util import get_fs
from bytelatent.distributed import (
DistributedArgs,
dist_mean_dict,
@@ -25,72 +25,17 @@ from bytelatent.distributed import (
get_world_size,
setup_torch_distributed,
)
from bytelatent.transformer import LMTransformer, LMTransformerArgs
from apps.main.generate import (
from bytelatent.generate import (
PackedCausalTransformerGenerator,
PackedCausalTransformerGeneratorArgs,
load_consolidated_model_and_tokenizer,
)
from bytelatent.transformer import LMTransformer, LMTransformerArgs
EVAL_FOLDER_NAME = "{:010d}"
logger = logging.getLogger()
@dataclass
class LMHarnessArgs:
tasks: Optional[List[Any]] = None
num_fewshot: Optional[int] = None
device: Optional[str] = None
use_cache: Optional[str] = None
cache_requests: bool = False
rewrite_requests_cache: bool = False
delete_requests_cache: bool = False
limit: Optional[Union[int, float]] = None
bootstrap_iters: int = 100000
check_integrity: bool = False
write_out: bool = False
log_samples: bool = True
system_instruction: Optional[str] = None
apply_chat_template: Union[bool, str] = False
fewshot_as_multiturn: bool = False
gen_kwargs: Optional[str] = None
verbosity: str = "INFO"
predict_only: bool = False
random_seed: int = 0
numpy_random_seed: int = 1234
torch_random_seed: int = 1234
fewshot_random_seed: int = 1234
@dataclass
class ValidationArgs:
max_steps: Optional[int] = (
None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
)
use_val_from_train_src: bool = True # Use the validation set from training sources
root_dir: str = ""
sources: List[str] = field(default_factory=list) # Other sources to eval on
@dataclass
class EvalArgs:
name: str = "evals"
dump_dir: Optional[str] = None
metric_log_dir: Optional[str] = None
ckpt_dir: str = ""
generator: PackedCausalTransformerGeneratorArgs = field(
default_factory=PackedCausalTransformerGeneratorArgs
)
harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)
wandb: Optional[Any] = None
global_step: Optional[int] = None # for in-training evaluation
def all_dicts_same(dict_list):
if not dict_list: # Check if the list is empty
return True
@@ -120,7 +65,7 @@ class EvalHarnessLM(LM):
self._world_size = get_world_size()
self.device = generator.device
def generate_until(self, requests: List[Instance]) -> List[str]:
def generate_until(self, requests: list[Instance]) -> list[str]:
prompts, gen_args = zip(*[req.args for req in requests])
assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
gen_args = gen_args[0]
@@ -141,7 +86,7 @@
filtered_gen.append(g)
return filtered_gen
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
prompts, continuations = zip(*[req.args for req in requests])
inputs = [req.args[0] + req.args[1] for req in requests]
max_gen_len = self.generator.max_gen_len
@@ -158,7 +103,7 @@
self.generator.max_gen_len = max_gen_len
return results
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
prompts = [req.args[0] for req in requests]
max_gen_len = self.generator.max_gen_len
# We temporarily lower max gen len
@@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
return all_val_metrics
def launch_eval(cfg: EvalArgs):
def launch_eval(eval_args: EvalArgs):
if not torch.distributed.is_initialized():
setup_torch_distributed(DistributedArgs())
fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
if (
Path(cfg.ckpt_dir).exists()
and (Path(cfg.ckpt_dir) / "params.json").exists()
and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
fs.exists(eval_args.ckpt_dir)
and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
):
consolidate_path = Path(cfg.ckpt_dir)
consolidate_path = eval_args.ckpt_dir
else:
consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
if not consolidate_path.exists() and get_global_rank() == 0:
consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)
consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
if not fs.exists(consolidate_path) and get_global_rank() == 0:
consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)
fs.mkdirs(eval_args.dump_dir, exist_ok=True)
with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
f.write(eval_args.model_dump_json())
consolidate_path = str(consolidate_path)
torch.distributed.barrier()
logger.info("Loading model")
# TODO: Make this general so that it works with either
# LMTransformer or Blt, similar with args
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
consolidate_path,
model_cls=LMTransformer,
model_args_cls=LMTransformerArgs,
)
logger.info("Model loaded")
model.eval()
generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)
generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
wrap = EvalHarnessLM(generator)
results = simple_evaluate(wrap, **asdict(cfg.harness))
# Redo
results = simple_evaluate(wrap, eval_args.harness.model_dump())
val_results = None
if cfg.validation:
val_results = eval_on_val(generator, cfg.validation, train_cfg)
if eval_args.validation:
val_results = eval_on_val(generator, eval_args.validation, train_cfg)
if get_global_rank() == 0:
with open(Path(cfg.dump_dir) / "results.json", "w") as f:
with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
f.write(json.dumps(results))
logger.info(f"All evaluation results: {results['results']}")
if val_results is not None:
with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
f.write(json.dumps(val_results))
logger.info(f"All validation results: {val_results}")
if cfg.metric_log_dir and get_global_rank() == 0:
metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"
if eval_args.metric_log_dir and get_global_rank() == 0:
metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
logger.info(f"Writing metric logs to {metric_log_path}")
timestamp = {
"created_at": datetime.utcnow().isoformat(),
}
if cfg.global_step is not None:
timestamp["global_step"] = cfg.global_step
if eval_args.global_step is not None:
timestamp["global_step"] = eval_args.global_step
print(
json.dumps(timestamp | results["results"]),
file=open(metric_log_path, mode="a"),
file=fs.open(metric_log_path, mode="a"),
flush=True,
)
val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
val_log_path = os.path.join(
eval_args.metric_log_dir, "metrics.validation.jsonl"
)
if val_results is not None:
print(
json.dumps(timestamp | val_results),
file=open(val_log_path, mode="a"),
file=fs.open(val_log_path, mode="a"),
flush=True,
)
@@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs):
def main():
"""
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
This accepts arguments as a dot list
So if the dataclass looks like
@dataclass
class DummyArgs:
name: str
model: LMTransformerArgsgs
@dataclass
class LMTransformerArgsgs:
dim: int
Then you can pass model.dim=32 to change values in LMTransformerArgsgs
or just name=tictac for top level attributes.
The behavior here is as follows:
1. We instantiate EvalArgs with its default values
2. We override those default values with the ones in the provided config file
3. We override the result with the additional arguments provided through command line
For example, if the config is the following
model:
dim: 128
n_layers: 4
and you call eval.py with eval.py model.dim=64
Then the final TrainArgs will have
model:
dim: 64
n_layers: 4
Plus all the default values in EvalArgs dataclass.
"""
cli_args = OmegaConf.from_cli()
file_cfg = OmegaConf.load(cli_args.config)
# We remove 'config' attribute from config as the underlying DataClass does not have it
del cli_args.config
default_cfg = OmegaConf.structured(EvalArgs())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
cfg = OmegaConf.to_object(cfg)
launch_eval(cfg)
eval_args = parse_args(EvalArgs)
launch_eval(eval_args)
if __name__ == "__main__":


@@ -1,20 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
import torch
from lingua.args import dataclass_from_dict
from lingua.tokenizers.abstract_tokenizer import Tokenizer
from lingua.tokenizers.build_tokenizer import build_tokenizer
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm
from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
from bytelatent.base_transformer import (
Attention,
causal_mask,
@@ -23,7 +19,10 @@ from bytelatent.base_transformer import (
lengths_to_start_ids,
)
from bytelatent.checkpoint import CONSOLIDATE_NAME
from bytelatent.transformer import LMTransformer, LMTransformerArgs
from bytelatent.data.file_util import get_fs
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.transformer import LMTransformer
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
@@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
return next_token.view(shape[:-1])
def pack_prompts(prompts: List[int]):
def pack_prompts(prompts: list[int]):
res = []
lengths = []
for i, p in enumerate(prompts):
@@ -120,22 +119,6 @@ class KVCache(nn.Module):
return self.k_cache, self.v_cache
@dataclass
class PackedCausalTransformerGeneratorArgs:
temperature: float = 0.0
top_p: Optional[float] = None
top_k: Optional[float] = None
max_gen_len: int = 512 # Maximum number of tokens to generate
max_tokens: int = 1024 # Maximum number of tokens that can go through the model
max_prompt_len: Optional[int] = None
until: List[str] = field(default_factory=list)
compile_prefilling: bool = False
reduce_generation_overhead: bool = False
show_progress: bool = False
dtype: Optional[str] = "bf16"
device: Optional[str] = "cuda"
class PackedCausalTransformerGenerator:
def __init__(
self,
@@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator:
def load_consolidated_model_and_tokenizer(
consolidated_path,
model_cls=LMTransformer,
model_args_cls=LMTransformerArgs,
):
ckpt_path = Path(consolidated_path)
config = ckpt_path / "params.json"
config = OmegaConf.load(config)
train_args_path = os.path.join(consolidated_path, "params.json")
fs = get_fs(train_args_path)
with fs.open(train_args_path) as f:
train_args = TrainArgs.model_validate_json(f.read())
if train_args.train_entropy_model:
model_args = train_args.entropy_model
model = LMTransformer(model_args)
else:
model_args = train_args.model
model = ByteLatentTransformer(model_args)
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
config.distributed.model_dtype
train_args.distributed.model_dtype
]
model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
model = model_cls(model_args)
st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
tokenizer = train_args.data.tokenizer_args.build()
st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
model.load_state_dict(st_dict["model"])
model = model.cuda().eval()
for param in model.parameters():
param.data = param.data.to(dtype=param_dtype)
return model, tokenizer, config
return model, tokenizer, train_args
def main():
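Aside (illustrative; the checkpoint path is a placeholder and a CUDA device is assumed, since loading moves the model to GPU): the loader now reads params.json into TrainArgs via pydantic and chooses LMTransformer or ByteLatentTransformer based on train_args.train_entropy_model, so a typical call looks like:

    from bytelatent.generate import (
        PackedCausalTransformerGenerator,
        PackedCausalTransformerGeneratorArgs,
        load_consolidated_model_and_tokenizer,
    )

    # The directory must contain params.json and the consolidated weights produced
    # by consolidate_checkpoints; the returned train_args is a TrainArgs instance.
    model, tokenizer, train_args = load_consolidated_model_and_tokenizer(
        "/checkpoints/run1/consolidated"
    )
    generator = PackedCausalTransformerGenerator(
        PackedCausalTransformerGeneratorArgs(max_gen_len=64), model, tokenizer
    )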


@@ -10,7 +10,7 @@ from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path
from timeit import default_timer as timer
from typing import Any, Dict, Type, TypeVar
from typing import Any, TypeVar
import torch
import torch.distributed
@@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import lr_scheduler
from bytelatent.args import TrainArgs
from bytelatent.args import TrainArgs, parse_args
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.data.data_types import DataLoaderState
from bytelatent.data.iterators.multiprocess_iterator import (
MultiprocessIterator,
MultiprocessIteratorState,
)
from bytelatent.data.iterators.packing_iterator import PackingIteratorState
from bytelatent.distributed import (
check_model_value_range,
clean_env,
@@ -39,6 +43,7 @@ from bytelatent.distributed import (
setup_env,
setup_torch_distributed,
)
from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
from bytelatent.logger import init_logger
from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
from bytelatent.model.blt import ByteLatentTransformer
@@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"):
return dict(items)
def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
"""
Converts a dictionary to a dataclass instance, recursively for nested structures.
"""
base = OmegaConf.structured(cls())
OmegaConf.set_struct(base, strict)
override = OmegaConf.create(data)
return OmegaConf.to_object(OmegaConf.merge(base, override))
def get_iterator_state_name(iterator_state):
if isinstance(iterator_state, MultiprocessIteratorState):
return "multiprocess"
elif isinstance(iterator_state, PackingIteratorState):
return "packing"
else:
raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
# TODO: Make this pydantic based instead of data class based
# TODO: Generalize this to any iterator state
@dataclass
class TrainState(Stateful):
step: int # Nb of steps taken by the optimizer
acc_step: int # Nb of accumulation steps done since last optimizer step
scheduler: lr_scheduler.LambdaLR
data_loader_state: DataLoaderState
data_loader_state: MultiprocessIteratorState | PackingIteratorState
scale: float = 1.0
data_loader_class: str | None = None
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> dict[str, Any]:
return {
"step": self.step,
"acc_step": self.acc_step,
"data_loader_state": self.data_loader_state.dict(),
"data_loader_state": self.data_loader_state.model_dump(),
"data_loader_class": get_iterator_state_name(self.data_loader_state),
"scheduler": self.scheduler.state_dict(),
}
def load_state_dict(self, state_dict):
self.step = state_dict["step"]
self.acc_step = state_dict["acc_step"]
self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"])
self.data_loader_class = state_dict["data_loader_class"]
if self.data_loader_class == "multiprocess":
self.data_loader_state = MultiprocessIteratorState(
**state_dict["data_loader_state"]
)
elif self.data_loader_class == "packing":
self.data_loader_state = PackingIteratorState(
**state_dict["data_loader_state"]
)
else:
raise ValueError(f"invalid data loader class: {self.data_loader_class}")
self.scheduler.load_state_dict(state_dict["scheduler"])
@@ -345,7 +363,10 @@ def train(args: TrainArgs):
nwords_since_last_log = 0
time_last_log = timer()
gc.collect()
while train_state.step < args.steps:
saved = False
while train_state.step < args.steps and (
args.max_steps is None or train_state.step < args.max_steps
):
# We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
train_state.acc_step += 1
train_state.acc_step = train_state.acc_step % args.grad_acc_steps
@@ -552,7 +573,6 @@ def train(args: TrainArgs):
f" pow: {gpu_mem_stats.power_draw/1000} W"
)
saved = False
if every_n_steps(
train_state, args.checkpoint.dump.every, acc_step=0
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
@@ -567,18 +587,14 @@ def train(args: TrainArgs):
if args.eval is not None and every_n_steps(
train_state, args.checkpoint.eval.every, acc_step=0
):
from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
eval_args = dataclass_from_dict(EvalArgs, args.eval)
eval_args = args.eval
eval_args.global_step = train_state.step
eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
eval_args.dump_dir = str(
os.path.join(
args.dump_dir,
"evals",
EVAL_FOLDER_NAME.format(train_state.step),
)
eval_args.dump_dir = os.path.join(
args.dump_dir,
"evals",
EVAL_FOLDER_NAME.format(train_state.step),
)
eval_args.metric_log_dir = args.dump_dir
if args.async_eval_gpus is None:
@@ -619,6 +635,9 @@ def train(args: TrainArgs):
args,
device_mesh=world_mesh,
)
if isinstance(data_loader, MultiprocessIterator):
logger.info("Closing MP iterator before exiting")
data_loader.shutdown()
gc.collect()
@@ -661,15 +680,7 @@ def main():
Plus all the default values in TrainArgs dataclass.
"""
cli_args = OmegaConf.from_cli()
file_cfg = OmegaConf.load(cli_args.config)
# We remove 'config' attribute from config as the underlying DataClass does not have it
del cli_args.config
default_cfg = OmegaConf.create(TrainArgs().model_dump())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
train_args = TrainArgs.model_validate(cfg)
train_args = parse_args(TrainArgs)
if train_args.debug_dynamo:
import torch._dynamo