This includes fixes that make checkpointing and reloading work correctly. (#35)

It also batches in a first set of changes toward fixing the eval code.

Summary:

Test Plan:
Pedro Rodriguez 2025-01-27 16:56:42 -08:00 committed by GitHub
parent 7622d28b74
commit 7044771a12
10 changed files with 221 additions and 237 deletions


@@ -544,7 +544,7 @@ def train(args: TrainArgs):
             if args.eval is not None and every_n_steps(
                 train_state, args.checkpoint.eval.every, acc_step=0
             ):
-                from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
+                from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval

                 eval_args = dataclass_from_dict(EvalArgs, args.eval)


@@ -5,6 +5,7 @@ from typing import Any

 import numpy as np
 import yaml
+from omegaconf import OmegaConf
 from pydantic import BaseModel, ConfigDict

 from bytelatent.checkpoint import CheckpointArgs
@@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
     return np.random.default_rng((seed, rank, world_size)).bit_generator.state


+def parse_args(args_cls):
+    cli_args = OmegaConf.from_cli()
+    file_cfg = OmegaConf.load(cli_args.config)
+    # We remove 'config' attribute from config as the underlying DataClass does not have it
+    del cli_args.config
+
+    default_cfg = OmegaConf.create(args_cls().model_dump())
+    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
+    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
+    pydantic_args = args_cls.model_validate(cfg)
+    return pydantic_args
+
+
 def distribute_data_to_rank(
     *,
     dataset_path: str,
@@ -71,6 +85,22 @@ def distribute_data_to_rank(
     return rank_to_arrow_iterator_params[rank]


+class PackedCausalTransformerGeneratorArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    temperature: float = 0.0
+    top_p: float | None = None
+    top_k: float | None = None
+    max_gen_len: int = 512  # Maximum number of tokens to generate
+    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
+    max_prompt_len: int | None = None
+    until: list[str] = []
+    compile_prefilling: bool = False
+    reduce_generation_overhead: bool = False
+    show_progress: bool = False
+    dtype: str | None = "bf16"
+    device: str | None = "cuda"
+
+
 class DataloaderArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     s3_profile: str | None = None
@@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel):
         return packing_iterator


+class LMHarnessArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    tasks: list[Any] | None = None
+    num_fewshot: int | None = None
+    device: str | None = None
+    use_cache: str | None = None
+    cache_requests: bool = False
+    rewrite_requests_cache: bool = False
+    delete_requests_cache: bool = False
+    limit: int | float | None = None
+    bootstrap_iters: int = 100000
+    check_integrity: bool = False
+    write_out: bool = False
+    log_samples: bool = True
+    system_instruction: str | None = None
+    apply_chat_template: bool | str = False
+    fewshot_as_multiturn: bool = False
+    gen_kwargs: str | None = None
+    verbosity: str = "INFO"
+    predict_only: bool = False
+    random_seed: int = 0
+    numpy_random_seed: int = 1234
+    torch_random_seed: int = 1234
+    fewshot_random_seed: int = 1234
+
+
+class ValidationArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    max_steps: int | None = (
+        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
+    )
+    use_val_from_train_src: bool = True  # Use the validation set from training sources
+    root_dir: str = ""
+    sources: list[str] = []  # Other sources to eval on
+
+
+class EvalArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    dump_dir: str
+    ckpt_dir: str
+    metric_log_dir: str | None = None
+    generator: PackedCausalTransformerGeneratorArgs = (
+        PackedCausalTransformerGeneratorArgs()
+    )
+    harness: LMHarnessArgs | None = LMHarnessArgs()
+    validation: ValidationArgs | None = ValidationArgs()
+    global_step: int | None = None  # for in-training evaluation
+    s3_profile: str | None = None
+
+
 class TrainArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     name: str = "lingua"
@@ -186,6 +268,9 @@ class TrainArgs(BaseModel):
     # Nb optimizer steps to take
     steps: int = 1000
+    # If not None, halt training after this many steps,
+    # useful for debugging
+    max_steps: int | None = None

     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
@@ -203,7 +288,7 @@ class TrainArgs(BaseModel):
     # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
     async_eval_gpus: int | None = None
-    eval: Any | None = None
+    eval: EvalArgs | None = None
     eval_on_gpus: int | None = None

     def dump_to_yaml_file(

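For context, the parse_args helper added above composes the final config in three layers, each overriding the previous one: the pydantic defaults of the args class, the YAML file named by config=..., and the remaining command-line dot-list. A minimal sketch of that precedence using OmegaConf directly (values are invented for illustration; parse_args additionally validates the merged dict through pydantic):

    from omegaconf import OmegaConf

    default_cfg = OmegaConf.create({"steps": 1000, "name": "lingua"})  # class defaults
    file_cfg = OmegaConf.create({"steps": 500})                        # from the YAML file
    cli_cfg = OmegaConf.create({"name": "debug_run"})                  # from the CLI dot-list
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
    assert cfg.steps == 500 and cfg.name == "debug_run"                # later layers win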

@@ -7,6 +7,7 @@ import re
 from pathlib import Path
 from typing import List, Optional, Tuple

+import fsspec
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
@@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import (
     set_state_dict,
 )

+from bytelatent.data.file_util import get_fs
 from bytelatent.distributed import get_is_master

 logger = logging.getLogger("CHECKPOINT")
@@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel):
     path: str | None = None
     init_ckpt_path: str | None = None
     continue_training_from_init: bool = False
+    s3_profile: str | None = None


 def _get_key_step(name: str):
     return int(re.findall(RE_DIGITS, name)[-1])


-def consolidate_checkpoints(ckpt_dir: str):
+def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
     """
     Consolidates all FSDP checkpoints in a directory to a single file
     Consolidate checkpoint is saved in a subdirectory of ckpt_dir
@@ -102,15 +105,17 @@ def load_from_checkpoint(
     dcp.load(state_dict, checkpoint_id=ckpt_dir)


+# TODO: Rewrite the file operations here to use fsspec to enable s3 writing.
 class CheckpointManager:
     def __init__(self, args: CheckpointArgs):
         self.path = args.path
+        self.fs = get_fs(self.path, s3_profile=args.s3_profile)
         self.dump_every = args.dump
         self.eval_every = args.eval
         self.init_ckpt_path = args.init_ckpt_path
         self.continue_training_from_init = args.continue_training_from_init

-        assert os.path.exists(
+        assert self.fs.exists(
             self.path
         ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"

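The checkpoint changes above route all path handling through an fsspec filesystem returned by get_fs, so the same CheckpointManager code can target local disk or S3 depending on the configured s3_profile. The diff does not show get_fs itself; a rough sketch of the dispatch it presumably performs, under that assumption:

    import fsspec

    def get_fs_sketch(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
        # s3:// paths go through s3fs with the given credentials profile; everything
        # else falls back to the local filesystem, so fs.exists()/fs.open() work for both.
        if path.startswith("s3://"):
            return fsspec.filesystem("s3", profile=s3_profile)
        return fsspec.filesystem("file")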

@@ -98,11 +98,4 @@ logging:
   freq: 10

 eval_on_gpus: 8
-eval:
-  dataset_dir: /checkpoint/amaia/codegen/datasets/eval
-  tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
-  generator:
-    max_tokens: 65536
-    dtype: bf16
-
-  mp_size: 1
+eval: null


@@ -72,11 +72,4 @@ logging:
   freq: 10

 eval_on_gpus: 8
-eval:
-  dataset_dir: ???
-  tasks: ???
-  generator:
-    max_tokens: 65536
-    dtype: bf16
-
-  mp_size: 1
+eval: null


@@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel):
     n_views: int = 2


-class DataLoaderState(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    multi_choice_state: MultiChoiceState
-    pack_tokens_state: BltPackTokensState
-    prefetch_state: PrefetchState
-
-
-BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
-
-
 class BltSequence(BaseModel):
     tokens: list[int]
     mask: list[bool]


@@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator):
         self.producer = None
         self.stop_iterating_event = None
         self.state_dumped_event = None
+        self.force_shutdown = False
+
+    def shutdown(self):
+        if self.producer is not None:
+            # This properly shuts things down
+            self.producer.kill()
+            self.force_shutdown = True

     def get_state(self) -> MultiprocessIteratorState:
         """
@@ -135,6 +142,10 @@ class MultiprocessIterator(StatefulIterator):
         to halt the background process and allow it to write the state to the main loop
         in order to not lose data
         """
+        if self.force_shutdown:
+            raise ValueError(
+                "State will be invalid if shutdown was forced before state persisted."
+            )
         if self.producer is None:
             serialized_prefetch_buffer = json.dumps(
                 [b.to_python_dict() for b in self.prefetch_buffer]
@@ -187,6 +198,10 @@ class MultiprocessIterator(StatefulIterator):
         )

     def create_iter(self):
+        if self.force_shutdown:
+            raise ValueError(
+                "Iterator may be invalid if shutdown was forced before state persisted."
+            )
         logging.info("Main thread: Creating MP iterator")
         # First yield from the stored prefetch buffer.
         if self.prefetch_buffer is not None:

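The shutdown()/force_shutdown additions above draw a line between a clean stop, where get_state() persists the prefetch buffer and background-iterator state, and a forced kill of the producer process, after which the state can no longer be trusted, so get_state() and create_iter() raise. A hedged sketch of the intended call order at the end of training (close_data_loader is a hypothetical helper, not part of the diff):

    def close_data_loader(data_loader):
        # Persist the resumable state while the iterator is still healthy, then
        # force the background producer down; afterwards the iterator refuses
        # get_state()/create_iter() by design.
        state = data_loader.get_state()
        data_loader.shutdown()
        return state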

@@ -4,20 +4,20 @@ import json
 import logging
 import os
 from collections import defaultdict
-from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from pathlib import Path
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any

 import torch
-from lingua.args import dump_config
-from lingua.data import init_choice_state, setup_sources
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from omegaconf import OmegaConf
+from pydantic import BaseModel, ConfigDict

+from bytelatent.args import EvalArgs, ValidationArgs, parse_args
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
+from bytelatent.data.file_util import get_fs
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
@@ -25,72 +25,17 @@ from bytelatent.distributed import (
     get_world_size,
     setup_torch_distributed,
 )
-from bytelatent.transformer import LMTransformer, LMTransformerArgs
-from apps.main.generate import (
+from bytelatent.generate import (
     PackedCausalTransformerGenerator,
-    PackedCausalTransformerGeneratorArgs,
     load_consolidated_model_and_tokenizer,
 )
+from bytelatent.transformer import LMTransformer, LMTransformerArgs

 EVAL_FOLDER_NAME = "{:010d}"

 logger = logging.getLogger()


-@dataclass
-class LMHarnessArgs:
-    tasks: Optional[List[Any]] = None
-    num_fewshot: Optional[int] = None
-    device: Optional[str] = None
-    use_cache: Optional[str] = None
-    cache_requests: bool = False
-    rewrite_requests_cache: bool = False
-    delete_requests_cache: bool = False
-    limit: Optional[Union[int, float]] = None
-    bootstrap_iters: int = 100000
-    check_integrity: bool = False
-    write_out: bool = False
-    log_samples: bool = True
-    system_instruction: Optional[str] = None
-    apply_chat_template: Union[bool, str] = False
-    fewshot_as_multiturn: bool = False
-    gen_kwargs: Optional[str] = None
-    verbosity: str = "INFO"
-    predict_only: bool = False
-    random_seed: int = 0
-    numpy_random_seed: int = 1234
-    torch_random_seed: int = 1234
-    fewshot_random_seed: int = 1234
-
-
-@dataclass
-class ValidationArgs:
-    max_steps: Optional[int] = (
-        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
-    )
-    use_val_from_train_src: bool = True  # Use the validation set from training sources
-    root_dir: str = ""
-    sources: List[str] = field(default_factory=list)  # Other sources to eval on
-
-
-@dataclass
-class EvalArgs:
-    name: str = "evals"
-    dump_dir: Optional[str] = None
-    metric_log_dir: Optional[str] = None
-    ckpt_dir: str = ""
-    generator: PackedCausalTransformerGeneratorArgs = field(
-        default_factory=PackedCausalTransformerGeneratorArgs
-    )
-    harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
-    validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)
-    wandb: Optional[Any] = None
-    global_step: Optional[int] = None  # for in-training evaluation
-
-
 def all_dicts_same(dict_list):
     if not dict_list:  # Check if the list is empty
         return True
@@ -120,7 +65,7 @@ class EvalHarnessLM(LM):
         self._world_size = get_world_size()
         self.device = generator.device

-    def generate_until(self, requests: List[Instance]) -> List[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         prompts, gen_args = zip(*[req.args for req in requests])
         assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
         gen_args = gen_args[0]
@@ -141,7 +86,7 @@ class EvalHarnessLM(LM):
                 filtered_gen.append(g)
         return filtered_gen

-    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         prompts, continuations = zip(*[req.args for req in requests])
         inputs = [req.args[0] + req.args[1] for req in requests]
         max_gen_len = self.generator.max_gen_len
@@ -158,7 +103,7 @@ class EvalHarnessLM(LM):
         self.generator.max_gen_len = max_gen_len
         return results

-    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         prompts = [req.args[0] for req in requests]
         max_gen_len = self.generator.max_gen_len
         # We temporarily lower max gen len
@@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
     return all_val_metrics


-def launch_eval(cfg: EvalArgs):
+def launch_eval(eval_args: EvalArgs):
     if not torch.distributed.is_initialized():
         setup_torch_distributed(DistributedArgs())
+
+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
     if (
-        Path(cfg.ckpt_dir).exists()
-        and (Path(cfg.ckpt_dir) / "params.json").exists()
-        and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
     ):
-        consolidate_path = Path(cfg.ckpt_dir)
+        consolidate_path = eval_args.ckpt_dir
     else:
-        consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
-        if not consolidate_path.exists() and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)
+        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
+        if not fs.exists(consolidate_path) and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)

-    Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
-    dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)
+    fs.mkdirs(eval_args.dump_dir, exist_ok=True)
+    with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
+        f.write(eval_args.model_dump_json())

-    consolidate_path = str(consolidate_path)
     torch.distributed.barrier()
     logger.info("Loading model")
+    # TODO: Make this general so that it works with either
+    # LMTransformer or Blt, similar with args
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
-        model_cls=LMTransformer,
-        model_args_cls=LMTransformerArgs,
     )
     logger.info("Model loaded")
     model.eval()
-    generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)
+    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)

     wrap = EvalHarnessLM(generator)
-    results = simple_evaluate(wrap, **asdict(cfg.harness))
+    # Redo
+    results = simple_evaluate(wrap, eval_args.harness.model_dump())
     val_results = None
-    if cfg.validation:
-        val_results = eval_on_val(generator, cfg.validation, train_cfg)
+    if eval_args.validation:
+        val_results = eval_on_val(generator, eval_args.validation, train_cfg)
     if get_global_rank() == 0:
-        with open(Path(cfg.dump_dir) / "results.json", "w") as f:
+        with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
         logger.info(f"All evaluation results: {results['results']}")
         if val_results is not None:
-            with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
+            with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
                 f.write(json.dumps(val_results))
             logger.info(f"All validation results: {val_results}")
-    if cfg.metric_log_dir and get_global_rank() == 0:
-        metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"
+    if eval_args.metric_log_dir and get_global_rank() == 0:
+        metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")

         logger.info(f"Writing metric logs to {metric_log_path}")
         timestamp = {
             "created_at": datetime.utcnow().isoformat(),
         }
-        if cfg.global_step is not None:
-            timestamp["global_step"] = cfg.global_step
+        if eval_args.global_step is not None:
+            timestamp["global_step"] = eval_args.global_step
         print(
             json.dumps(timestamp | results["results"]),
-            file=open(metric_log_path, mode="a"),
+            file=fs.open(metric_log_path, mode="a"),
             flush=True,
         )

-        val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
+        val_log_path = os.path.join(
+            eval_args.metric_log_dir, "metrics.validation.jsonl"
+        )
         if val_results is not None:
             print(
                 json.dumps(timestamp | val_results),
-                file=open(val_log_path, mode="a"),
+                file=fs.open(val_log_path, mode="a"),
                 flush=True,
             )
@@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs):


 def main():
-    """
-    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
-    This accepts arguments as a dot list
-    So if the dataclass looks like
-
-    @dataclass
-    class DummyArgs:
-        name: str
-        model: LMTransformerArgsgs
-
-    @dataclass
-    class LMTransformerArgsgs:
-        dim: int
-
-    Then you can pass model.dim=32 to change values in LMTransformerArgsgs
-    or just name=tictac for top level attributes.
-
-    The behavior here is as follows:
-    1. We instantiate EvalArgs with its default values
-    2. We override those default values with the ones in the provided config file
-    3. We override the result with the additional arguments provided through command line
-
-    For example, if the config is the following
-
-    model:
-        dim: 128
-        n_layers: 4
-
-    and you call eval.py with eval.py model.dim=64
-
-    Then the final TrainArgs will have
-
-    model:
-        dim: 64
-        n_layers: 4
-
-    Plus all the default values in EvalArgs dataclass.
-    """
-    cli_args = OmegaConf.from_cli()
-    file_cfg = OmegaConf.load(cli_args.config)
-    # We remove 'config' attribute from config as the underlying DataClass does not have it
-    del cli_args.config
-
-    default_cfg = OmegaConf.structured(EvalArgs())
-    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
-    cfg = OmegaConf.to_object(cfg)
-    launch_eval(cfg)
+    eval_args = parse_args(EvalArgs)
+    launch_eval(eval_args)


 if __name__ == "__main__":

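With main() reduced to parse_args(EvalArgs) plus launch_eval, evaluation is driven the same way as training: a YAML file supplies the EvalArgs fields and dot-list arguments override it. A hypothetical invocation (the paths are made up, and the module entry point assumes eval.py keeps its __main__ guard):

    python -m bytelatent.eval config=configs/eval.yaml ckpt_dir=/checkpoints/my_run/0000001000 dump_dir=/tmp/eval_dump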

@@ -1,20 +1,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import os
 import time
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import List, Optional

 import torch
-from lingua.args import dataclass_from_dict
-from lingua.tokenizers.abstract_tokenizer import Tokenizer
-from lingua.tokenizers.build_tokenizer import build_tokenizer
 from omegaconf import OmegaConf
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import create_block_mask
 from tqdm import tqdm

+from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
     Attention,
     causal_mask,
@@ -23,7 +19,10 @@ from bytelatent.base_transformer import (
     lengths_to_start_ids,
 )
 from bytelatent.checkpoint import CONSOLIDATE_NAME
-from bytelatent.transformer import LMTransformer, LMTransformerArgs
+from bytelatent.data.file_util import get_fs
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
+from bytelatent.transformer import LMTransformer


 def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
@@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
     return next_token.view(shape[:-1])


-def pack_prompts(prompts: List[int]):
+def pack_prompts(prompts: list[int]):
     res = []
     lengths = []
     for i, p in enumerate(prompts):
@@ -120,22 +119,6 @@ class KVCache(nn.Module):
         return self.k_cache, self.v_cache


-@dataclass
-class PackedCausalTransformerGeneratorArgs:
-    temperature: float = 0.0
-    top_p: Optional[float] = None
-    top_k: Optional[float] = None
-    max_gen_len: int = 512  # Maximum number of tokens to generate
-    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
-    max_prompt_len: Optional[int] = None
-    until: List[str] = field(default_factory=list)
-    compile_prefilling: bool = False
-    reduce_generation_overhead: bool = False
-    show_progress: bool = False
-    dtype: Optional[str] = "bf16"
-    device: Optional[str] = "cuda"
-
-
 class PackedCausalTransformerGenerator:
     def __init__(
         self,
@@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator:

 def load_consolidated_model_and_tokenizer(
     consolidated_path,
-    model_cls=LMTransformer,
-    model_args_cls=LMTransformerArgs,
 ):
-    ckpt_path = Path(consolidated_path)
-    config = ckpt_path / "params.json"
-    config = OmegaConf.load(config)
+    train_args_path = os.path.join(consolidated_path, "params.json")
+    fs = get_fs(train_args_path)
+    with fs.open(train_args_path) as f:
+        train_args = TrainArgs.model_validate_json(f.read())
+
+    if train_args.train_entropy_model:
+        model_args = train_args.entropy_model
+        model = LMTransformer(model_args)
+    else:
+        model_args = train_args.model
+        model = ByteLatentTransformer(model_args)

     param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        config.distributed.model_dtype
+        train_args.distributed.model_dtype
     ]
-    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
-    tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
-    model = model_cls(model_args)
-    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
+    tokenizer = train_args.data.tokenizer_args.build()
+    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
         param.data = param.data.to(dtype=param_dtype)
-    return model, tokenizer, config
+    return model, tokenizer, train_args


 def main():

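load_consolidated_model_and_tokenizer no longer takes model_cls/model_args_cls; it reads TrainArgs back from params.json and chooses LMTransformer or ByteLatentTransformer on its own. A hedged usage sketch (the checkpoint directory is invented for illustration):

    from bytelatent.args import PackedCausalTransformerGeneratorArgs
    from bytelatent.generate import (
        PackedCausalTransformerGenerator,
        load_consolidated_model_and_tokenizer,
    )

    # consolidated checkpoint dir containing params.json and the consolidated weights
    model, tokenizer, train_args = load_consolidated_model_and_tokenizer(
        "checkpoints/my_run/0000001000/consolidated"
    )
    generator = PackedCausalTransformerGenerator(
        PackedCausalTransformerGeneratorArgs(max_gen_len=64), model, tokenizer
    )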

@@ -10,7 +10,7 @@ from copy import deepcopy
 from dataclasses import asdict, dataclass
 from pathlib import Path
 from timeit import default_timer as timer
-from typing import Any, Dict, Type, TypeVar
+from typing import Any, TypeVar

 import torch
 import torch.distributed
@@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import lr_scheduler

-from bytelatent.args import TrainArgs
+from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
-from bytelatent.data.data_types import DataLoaderState
+from bytelatent.data.iterators.multiprocess_iterator import (
+    MultiprocessIterator,
+    MultiprocessIteratorState,
+)
+from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
@@ -39,6 +43,7 @@ from bytelatent.distributed import (
     setup_env,
     setup_torch_distributed,
 )
+from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
 from bytelatent.logger import init_logger
 from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
 from bytelatent.model.blt import ByteLatentTransformer
@@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"):
     return dict(items)


-def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
-    """
-    Converts a dictionary to a dataclass instance, recursively for nested structures.
-    """
-    base = OmegaConf.structured(cls())
-    OmegaConf.set_struct(base, strict)
-    override = OmegaConf.create(data)
-    return OmegaConf.to_object(OmegaConf.merge(base, override))
+def get_iterator_state_name(iterator_state):
+    if isinstance(iterator_state, MultiprocessIteratorState):
+        return "multiprocess"
+    elif isinstance(iterator_state, PackingIteratorState):
+        return "packing"
+    else:
+        raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")


+# TODO: Make this pydantic based instead of data class based
+# TODO: Generalize this to any iterator state
 @dataclass
 class TrainState(Stateful):
     step: int  # Nb of steps taken by the optimizer
     acc_step: int  # Nb of accumulation steps done since last optimizer step
     scheduler: lr_scheduler.LambdaLR
-    data_loader_state: DataLoaderState
+    data_loader_state: MultiprocessIteratorState | PackingIteratorState
     scale: float = 1.0
+    data_loader_class: str | None = None

-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         return {
             "step": self.step,
             "acc_step": self.acc_step,
-            "data_loader_state": self.data_loader_state.dict(),
+            "data_loader_state": self.data_loader_state.model_dump(),
+            "data_loader_class": get_iterator_state_name(self.data_loader_state),
             "scheduler": self.scheduler.state_dict(),
         }

     def load_state_dict(self, state_dict):
         self.step = state_dict["step"]
         self.acc_step = state_dict["acc_step"]
-        self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"])
+        self.data_loader_class = state_dict["data_loader_class"]
+        if self.data_loader_class == "multiprocess":
+            self.data_loader_state = MultiprocessIteratorState(
+                **state_dict["data_loader_state"]
+            )
+        elif self.data_loader_class == "packing":
+            self.data_loader_state = PackingIteratorState(
+                **state_dict["data_loader_state"]
+            )
+        else:
+            raise ValueError(f"invalid data loader class: {self.data_loader_class}")
         self.scheduler.load_state_dict(state_dict["scheduler"])
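The new data_loader_class tag exists because model_dump() flattens the pydantic iterator state into plain dicts for torch.distributed.checkpoint, which loses the concrete class; load_state_dict then needs the tag to rebuild the right state type. A small round-trip sketch under that assumption (train_state stands in for a live TrainState):

    # what state_dict() stores: plain, serializable values only
    payload = {
        "data_loader_state": train_state.data_loader_state.model_dump(),
        "data_loader_class": get_iterator_state_name(train_state.data_loader_state),
    }
    # what load_state_dict() does: the tag selects the pydantic class to rebuild
    if payload["data_loader_class"] == "multiprocess":
        restored = MultiprocessIteratorState(**payload["data_loader_state"])
    elif payload["data_loader_class"] == "packing":
        restored = PackingIteratorState(**payload["data_loader_state"])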
@@ -345,7 +363,10 @@ def train(args: TrainArgs):
     nwords_since_last_log = 0
     time_last_log = timer()
     gc.collect()
-    while train_state.step < args.steps:
+    saved = False
+    while train_state.step < args.steps and (
+        args.max_steps is None or train_state.step < args.max_steps
+    ):
         # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
         train_state.acc_step += 1
         train_state.acc_step = train_state.acc_step % args.grad_acc_steps
@@ -552,7 +573,6 @@ def train(args: TrainArgs):
                     f" pow: {gpu_mem_stats.power_draw/1000} W"
                 )

-            saved = False
            if every_n_steps(
                train_state, args.checkpoint.dump.every, acc_step=0
            ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
@@ -567,19 +587,15 @@ def train(args: TrainArgs):
            if args.eval is not None and every_n_steps(
                train_state, args.checkpoint.eval.every, acc_step=0
            ):
-                from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
-
-                eval_args = dataclass_from_dict(EvalArgs, args.eval)
+                eval_args = args.eval
                eval_args.global_step = train_state.step
                eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
-                eval_args.dump_dir = str(
-                    os.path.join(
-                        args.dump_dir,
-                        "evals",
-                        EVAL_FOLDER_NAME.format(train_state.step),
-                    )
+                eval_args.dump_dir = os.path.join(
+                    args.dump_dir,
+                    "evals",
+                    EVAL_FOLDER_NAME.format(train_state.step),
                )
                eval_args.metric_log_dir = args.dump_dir
                if args.async_eval_gpus is None:
                    launch_eval(eval_args)
@@ -619,6 +635,9 @@
         args,
         device_mesh=world_mesh,
     )
+    if isinstance(data_loader, MultiprocessIterator):
+        logger.info("Closing MP iterator before exiting")
+        data_loader.shutdown()
     gc.collect()


@@ -661,15 +680,7 @@ def main():
     Plus all the default values in TrainArgs dataclass.
     """
-    cli_args = OmegaConf.from_cli()
-    file_cfg = OmegaConf.load(cli_args.config)
-    # We remove 'config' attribute from config as the underlying DataClass does not have it
-    del cli_args.config
-
-    default_cfg = OmegaConf.create(TrainArgs().model_dump())
-    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
-    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
-    train_args = TrainArgs.model_validate(cfg)
+    train_args = parse_args(TrainArgs)

     if train_args.debug_dynamo:
         import torch._dynamo