Mirror of https://github.com/facebookresearch/blt.git (synced 2025-01-30 17:42:14 +00:00)
This includes fixes that make checkpointing and reloading work correctly (#35). It also batches in a first set of changes toward fixing the eval code.
This commit is contained in:
parent 7622d28b74
commit 7044771a12
@@ -544,7 +544,7 @@ def train(args: TrainArgs):
         if args.eval is not None and every_n_steps(
             train_state, args.checkpoint.eval.every, acc_step=0
         ):
-            from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
+            from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
 
             eval_args = dataclass_from_dict(EvalArgs, args.eval)
@@ -5,6 +5,7 @@ from typing import Any
 import numpy as np
 import yaml
+from omegaconf import OmegaConf
 from pydantic import BaseModel, ConfigDict
 
 from bytelatent.checkpoint import CheckpointArgs
@@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
     return np.random.default_rng((seed, rank, world_size)).bit_generator.state
 
 
+def parse_args(args_cls):
+    cli_args = OmegaConf.from_cli()
+    file_cfg = OmegaConf.load(cli_args.config)
+    # We remove 'config' attribute from config as the underlying DataClass does not have it
+    del cli_args.config
+
+    default_cfg = OmegaConf.create(args_cls().model_dump())
+    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
+    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
+    pydantic_args = args_cls.model_validate(cfg)
+    return pydantic_args
+
+
 def distribute_data_to_rank(
     *,
     dataset_path: str,
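For orientation, here is a minimal, runnable sketch of the merge order the new parse_args implements: defaults from the pydantic model, then the config file, then CLI overrides. DummyArgs and the in-memory configs below are illustrative stand-ins, not part of the diff.

from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict


class DummyArgs(BaseModel):  # hypothetical args class, only for this sketch
    model_config = ConfigDict(extra="forbid")
    name: str = "lingua"
    dim: int = 128


# Stand-ins for OmegaConf.load(cli_args.config) and OmegaConf.from_cli().
file_cfg = OmegaConf.create({"name": "demo"})
cli_cfg = OmegaConf.from_dotlist(["dim=64"])

default_cfg = OmegaConf.create(DummyArgs().model_dump())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
print(DummyArgs.model_validate(cfg))  # name='demo' dim=64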
@@ -71,6 +85,22 @@ def distribute_data_to_rank(
     return rank_to_arrow_iterator_params[rank]
 
 
+class PackedCausalTransformerGeneratorArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    temperature: float = 0.0
+    top_p: float | None = None
+    top_k: float | None = None
+    max_gen_len: int = 512  # Maximum number of tokens to generate
+    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
+    max_prompt_len: int | None = None
+    until: list[str] = []
+    compile_prefilling: bool = False
+    reduce_generation_overhead: bool = False
+    show_progress: bool = False
+    dtype: str | None = "bf16"
+    device: str | None = "cuda"
+
+
 class DataloaderArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     s3_profile: str | None = None
@@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel):
         return packing_iterator
 
 
+class LMHarnessArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    tasks: list[Any] | None = None
+    num_fewshot: int | None = None
+    device: str | None = None
+    use_cache: str | None = None
+    cache_requests: bool = False
+    rewrite_requests_cache: bool = False
+    delete_requests_cache: bool = False
+    limit: int | float | None = None
+    bootstrap_iters: int = 100000
+    check_integrity: bool = False
+    write_out: bool = False
+    log_samples: bool = True
+    system_instruction: str | None = None
+    apply_chat_template: bool | str = False
+    fewshot_as_multiturn: bool = False
+    gen_kwargs: str | None = None
+    verbosity: str = "INFO"
+    predict_only: bool = False
+    random_seed: int = 0
+    numpy_random_seed: int = 1234
+    torch_random_seed: int = 1234
+    fewshot_random_seed: int = 1234
+
+
+class ValidationArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    max_steps: int | None = (
+        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
+    )
+    use_val_from_train_src: bool = True  # Use the validation set from training sources
+    root_dir: str = ""
+    sources: list[str] = []  # Other sources to eval on
+
+
+class EvalArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    dump_dir: str
+    ckpt_dir: str
+    metric_log_dir: str | None = None
+    generator: PackedCausalTransformerGeneratorArgs = (
+        PackedCausalTransformerGeneratorArgs()
+    )
+
+    harness: LMHarnessArgs | None = LMHarnessArgs()
+    validation: ValidationArgs | None = ValidationArgs()
+
+    global_step: int | None = None  # for in-training evaluation
+    s3_profile: str | None = None
+
+
 class TrainArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     name: str = "lingua"
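A small sketch of how these nested argument groups behave; the directory values are placeholders, not paths used by the repo.

from bytelatent.args import EvalArgs

eval_args = EvalArgs(dump_dir="/tmp/evals", ckpt_dir="/tmp/checkpoints/0000001000")
assert eval_args.generator.max_gen_len == 512  # nested defaults are filled in
assert eval_args.harness is not None and eval_args.harness.verbosity == "INFO"
# extra="forbid" means typos fail loudly instead of being silently ignored:
# EvalArgs(dump_dir="x", ckpt_dir="y", dump_dirr="z") raises a ValidationError.
print(eval_args.model_dump_json(indent=2))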
@@ -186,6 +268,9 @@ class TrainArgs(BaseModel):
 
     # Nb optimizer steps to take
     steps: int = 1000
+    # If not None, halt training after this many steps,
+    # useful for debugging
+    max_steps: int | None = None
 
     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
@@ -203,7 +288,7 @@ class TrainArgs(BaseModel):
 
     # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
     async_eval_gpus: int | None = None
-    eval: Any | None = None
+    eval: EvalArgs | None = None
     eval_on_gpus: int | None = None
 
     def dump_to_yaml_file(
@@ -7,6 +7,7 @@ import re
 from pathlib import Path
 from typing import List, Optional, Tuple
 
+import fsspec
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
@@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import (
     set_state_dict,
 )
 
+from bytelatent.data.file_util import get_fs
 from bytelatent.distributed import get_is_master
 
 logger = logging.getLogger("CHECKPOINT")
@@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel):
     path: str | None = None
     init_ckpt_path: str | None = None
     continue_training_from_init: bool = False
+    s3_profile: str | None = None
 
 
 def _get_key_step(name: str):
     return int(re.findall(RE_DIGITS, name)[-1])
 
 
-def consolidate_checkpoints(ckpt_dir: str):
+def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
     """
     Consolidates all FSDP checkpoints in a directory to a single file
     Consolidate checkpoint is saved in a subdirectory of ckpt_dir
@@ -102,15 +105,17 @@ def load_from_checkpoint(
     dcp.load(state_dict, checkpoint_id=ckpt_dir)
 
 
+# TODO: Rewrite the file operations here to use fsspec to enable s3 writing.
 class CheckpointManager:
     def __init__(self, args: CheckpointArgs):
         self.path = args.path
+        self.fs = get_fs(self.path, s3_profile=args.s3_profile)
         self.dump_every = args.dump
         self.eval_every = args.eval
         self.init_ckpt_path = args.init_ckpt_path
         self.continue_training_from_init = args.continue_training_from_init
 
-        assert os.path.exists(
+        assert self.fs.exists(
             self.path
         ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
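The point of routing file operations through get_fs is that the same checkpoint code can target local disk or S3 via fsspec. A minimal sketch with a made-up local path; only calls that appear in the diff (get_fs, mkdirs, exists) are used.

import fsspec

from bytelatent.data.file_util import get_fs

ckpt_path = "/tmp/blt_checkpoints"  # could be an s3:// URI when s3_profile is set
fs: fsspec.AbstractFileSystem = get_fs(ckpt_path, s3_profile=None)
fs.mkdirs(ckpt_path, exist_ok=True)
assert fs.exists(ckpt_path)  # the same check CheckpointManager.__init__ now performs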
@@ -98,11 +98,4 @@ logging:
   freq: 10
 
-eval_on_gpus: 8
-eval:
-  dataset_dir: /checkpoint/amaia/codegen/datasets/eval
-  tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
-  generator:
-    max_tokens: 65536
-    dtype: bf16
-
 mp_size: 1
+eval: null
@@ -72,11 +72,4 @@ logging:
   freq: 10
 
-eval_on_gpus: 8
-eval:
-  dataset_dir: ???
-  tasks: ???
-  generator:
-    max_tokens: 65536
-    dtype: bf16
-
 mp_size: 1
+eval: null
@@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel):
     n_views: int = 2
 
 
-class DataLoaderState(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    multi_choice_state: MultiChoiceState
-    pack_tokens_state: BltPackTokensState
-    prefetch_state: PrefetchState
-
-
-BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
-
-
 class BltSequence(BaseModel):
     tokens: list[int]
     mask: list[bool]
@@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator):
         self.producer = None
         self.stop_iterating_event = None
         self.state_dumped_event = None
+        self.force_shutdown = False
+
+    def shutdown(self):
+        if self.producer is not None:
+            # This properly shuts things down
+            self.producer.kill()
+            self.force_shutdown = True
 
     def get_state(self) -> MultiprocessIteratorState:
         """
@@ -135,6 +142,10 @@ class MultiprocessIterator(StatefulIterator):
         to halt the background process and allow it to write the state to the main loop
         in order to not lose data
         """
+        if self.force_shutdown:
+            raise ValueError(
+                "State will be invalid if shutdown was forced before state persisted."
+            )
         if self.producer is None:
             serialized_prefetch_buffer = json.dumps(
                 [b.to_python_dict() for b in self.prefetch_buffer]
@@ -187,6 +198,10 @@ class MultiprocessIterator(StatefulIterator):
             )
 
     def create_iter(self):
+        if self.force_shutdown:
+            raise ValueError(
+                "Iterator may be invalid if shutdown was forced before state persisted."
+            )
         logging.info("Main thread: Creating MP iterator")
         # First yield from the stored prefetch buffer.
         if self.prefetch_buffer is not None:
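Schematically, the new flag turns "forced shutdown, then reuse" into an explicit error instead of silently handing out stale state. A toy stand-in, not the real class, showing just the guard behaviour:

# Toy stand-in illustrating the guard added above; the real MultiprocessIterator
# also kills the background producer process in shutdown().
class GuardedIterator:
    def __init__(self) -> None:
        self.force_shutdown = False

    def shutdown(self) -> None:
        self.force_shutdown = True

    def get_state(self) -> dict:
        if self.force_shutdown:
            raise ValueError(
                "State will be invalid if shutdown was forced before state persisted."
            )
        return {"prefetch_buffer": []}


it = GuardedIterator()
it.shutdown()
try:
    it.get_state()
except ValueError as err:
    print(err)  # reuse after a forced shutdown is refused rather than corrupting state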
@@ -4,20 +4,20 @@ import json
 import logging
 import os
 from collections import defaultdict
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from pathlib import Path
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any
 
 import torch
-from lingua.args import dump_config
-from lingua.data import init_choice_state, setup_sources
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from omegaconf import OmegaConf
+from pydantic import BaseModel, ConfigDict
 
+from bytelatent.args import EvalArgs, ValidationArgs, parse_args
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
+from bytelatent.data.file_util import get_fs
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
@@ -25,72 +25,17 @@ from bytelatent.distributed import (
     get_world_size,
     setup_torch_distributed,
 )
-from bytelatent.transformer import LMTransformer, LMTransformerArgs
-
-from apps.main.generate import (
+from bytelatent.generate import (
     PackedCausalTransformerGenerator,
     PackedCausalTransformerGeneratorArgs,
     load_consolidated_model_and_tokenizer,
 )
+from bytelatent.transformer import LMTransformer, LMTransformerArgs
 
 EVAL_FOLDER_NAME = "{:010d}"
 
 logger = logging.getLogger()
 
 
-@dataclass
-class LMHarnessArgs:
-    tasks: Optional[List[Any]] = None
-    num_fewshot: Optional[int] = None
-    device: Optional[str] = None
-    use_cache: Optional[str] = None
-    cache_requests: bool = False
-    rewrite_requests_cache: bool = False
-    delete_requests_cache: bool = False
-    limit: Optional[Union[int, float]] = None
-    bootstrap_iters: int = 100000
-    check_integrity: bool = False
-    write_out: bool = False
-    log_samples: bool = True
-    system_instruction: Optional[str] = None
-    apply_chat_template: Union[bool, str] = False
-    fewshot_as_multiturn: bool = False
-    gen_kwargs: Optional[str] = None
-    verbosity: str = "INFO"
-    predict_only: bool = False
-    random_seed: int = 0
-    numpy_random_seed: int = 1234
-    torch_random_seed: int = 1234
-    fewshot_random_seed: int = 1234
-
-
-@dataclass
-class ValidationArgs:
-    max_steps: Optional[int] = (
-        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
-    )
-    use_val_from_train_src: bool = True  # Use the validation set from training sources
-    root_dir: str = ""
-    sources: List[str] = field(default_factory=list)  # Other sources to eval on
-
-
-@dataclass
-class EvalArgs:
-    name: str = "evals"
-    dump_dir: Optional[str] = None
-    metric_log_dir: Optional[str] = None
-    ckpt_dir: str = ""
-    generator: PackedCausalTransformerGeneratorArgs = field(
-        default_factory=PackedCausalTransformerGeneratorArgs
-    )
-    harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
-    validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)
-
-    wandb: Optional[Any] = None
-
-    global_step: Optional[int] = None  # for in-training evaluation
-
-
 def all_dicts_same(dict_list):
     if not dict_list:  # Check if the list is empty
         return True
@@ -120,7 +65,7 @@ class EvalHarnessLM(LM):
         self._world_size = get_world_size()
         self.device = generator.device
 
-    def generate_until(self, requests: List[Instance]) -> List[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         prompts, gen_args = zip(*[req.args for req in requests])
         assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
         gen_args = gen_args[0]
@@ -141,7 +86,7 @@ class EvalHarnessLM(LM):
             filtered_gen.append(g)
         return filtered_gen
 
-    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         prompts, continuations = zip(*[req.args for req in requests])
         inputs = [req.args[0] + req.args[1] for req in requests]
         max_gen_len = self.generator.max_gen_len
@@ -158,7 +103,7 @@ class EvalHarnessLM(LM):
         self.generator.max_gen_len = max_gen_len
         return results
 
-    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         prompts = [req.args[0] for req in requests]
         max_gen_len = self.generator.max_gen_len
         # We temporarily lower max gen len
@@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
     return all_val_metrics
 
 
-def launch_eval(cfg: EvalArgs):
+def launch_eval(eval_args: EvalArgs):
     if not torch.distributed.is_initialized():
         setup_torch_distributed(DistributedArgs())
 
+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
     if (
-        Path(cfg.ckpt_dir).exists()
-        and (Path(cfg.ckpt_dir) / "params.json").exists()
-        and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
     ):
-        consolidate_path = Path(cfg.ckpt_dir)
+        consolidate_path = eval_args.ckpt_dir
     else:
-        consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
-        if not consolidate_path.exists() and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)
+        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
+        if not fs.exists(consolidate_path) and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
 
-    Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
-    dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)
+    fs.mkdirs(eval_args.dump_dir, exist_ok=True)
+    with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
+        f.write(eval_args.model_dump_json())
 
     consolidate_path = str(consolidate_path)
     torch.distributed.barrier()
     logger.info("Loading model")
+    # TODO: Make this general so that it works with either
+    # LMTransformer or Blt, similar with args
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
-        model_cls=LMTransformer,
-        model_args_cls=LMTransformerArgs,
     )
     logger.info("Model loaded")
     model.eval()
-    generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)
+    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
 
     wrap = EvalHarnessLM(generator)
-    results = simple_evaluate(wrap, **asdict(cfg.harness))
+    # Redo
+    results = simple_evaluate(wrap, eval_args.harness.model_dump())
     val_results = None
-    if cfg.validation:
-        val_results = eval_on_val(generator, cfg.validation, train_cfg)
+    if eval_args.validation:
+        val_results = eval_on_val(generator, eval_args.validation, train_cfg)
     if get_global_rank() == 0:
-        with open(Path(cfg.dump_dir) / "results.json", "w") as f:
+        with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
         logger.info(f"All evaluation results: {results['results']}")
         if val_results is not None:
-            with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
+            with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
                 f.write(json.dumps(val_results))
             logger.info(f"All validation results: {val_results}")
-    if cfg.metric_log_dir and get_global_rank() == 0:
-        metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"
+    if eval_args.metric_log_dir and get_global_rank() == 0:
+        metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
 
         logger.info(f"Writing metric logs to {metric_log_path}")
         timestamp = {
             "created_at": datetime.utcnow().isoformat(),
         }
-        if cfg.global_step is not None:
-            timestamp["global_step"] = cfg.global_step
+        if eval_args.global_step is not None:
+            timestamp["global_step"] = eval_args.global_step
         print(
             json.dumps(timestamp | results["results"]),
-            file=open(metric_log_path, mode="a"),
+            file=fs.open(metric_log_path, mode="a"),
            flush=True,
        )
 
-        val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
+        val_log_path = os.path.join(
+            eval_args.metric_log_dir, "metrics.validation.jsonl"
+        )
         if val_results is not None:
             print(
                 json.dumps(timestamp | val_results),
-                file=open(val_log_path, mode="a"),
+                file=fs.open(val_log_path, mode="a"),
                 flush=True,
             )
@@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs):
 
 
 def main():
-    """
-    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
-    This accepts arguments as a dot list
-    So if the dataclass looks like
-
-    @dataclass
-    class DummyArgs:
-        name: str
-        model: LMTransformerArgsgs
-
-    @dataclass
-    class LMTransformerArgsgs:
-        dim: int
-
-    Then you can pass model.dim=32 to change values in LMTransformerArgsgs
-    or just name=tictac for top level attributes.
-
-    The behavior here is as follows:
-    1. We instantiate EvalArgs with its default values
-    2. We override those default values with the ones in the provided config file
-    3. We override the result with the additional arguments provided through command line
-
-    For example, if the config is the following
-
-    model:
-        dim: 128
-        n_layers: 4
-
-    and you call eval.py with eval.py model.dim=64
-
-    Then the final TrainArgs will have
-
-    model:
-        dim: 64
-        n_layers: 4
-
-    Plus all the default values in EvalArgs dataclass.
-    """
-    cli_args = OmegaConf.from_cli()
-    file_cfg = OmegaConf.load(cli_args.config)
-    # We remove 'config' attribute from config as the underlying DataClass does not have it
-    del cli_args.config
-
-    default_cfg = OmegaConf.structured(EvalArgs())
-    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
-    cfg = OmegaConf.to_object(cfg)
-    launch_eval(cfg)
+    eval_args = parse_args(EvalArgs)
+    launch_eval(eval_args)
 
 
 if __name__ == "__main__":
@@ -1,20 +1,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 
+import os
 import time
-from dataclasses import dataclass, field
 from pathlib import Path
-from typing import List, Optional
 
 import torch
-from lingua.args import dataclass_from_dict
-from lingua.tokenizers.abstract_tokenizer import Tokenizer
-from lingua.tokenizers.build_tokenizer import build_tokenizer
-from omegaconf import OmegaConf
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import create_block_mask
 from tqdm import tqdm
 
+from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
     Attention,
     causal_mask,
@@ -23,7 +19,10 @@ from bytelatent.base_transformer import (
     lengths_to_start_ids,
 )
 from bytelatent.checkpoint import CONSOLIDATE_NAME
-from bytelatent.transformer import LMTransformer, LMTransformerArgs
+from bytelatent.data.file_util import get_fs
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
+from bytelatent.transformer import LMTransformer
 
 
 def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
@@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
     return next_token.view(shape[:-1])
 
 
-def pack_prompts(prompts: List[int]):
+def pack_prompts(prompts: list[int]):
     res = []
     lengths = []
     for i, p in enumerate(prompts):
@@ -120,22 +119,6 @@ class KVCache(nn.Module):
         return self.k_cache, self.v_cache
 
 
-@dataclass
-class PackedCausalTransformerGeneratorArgs:
-    temperature: float = 0.0
-    top_p: Optional[float] = None
-    top_k: Optional[float] = None
-    max_gen_len: int = 512  # Maximum number of tokens to generate
-    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
-    max_prompt_len: Optional[int] = None
-    until: List[str] = field(default_factory=list)
-    compile_prefilling: bool = False
-    reduce_generation_overhead: bool = False
-    show_progress: bool = False
-    dtype: Optional[str] = "bf16"
-    device: Optional[str] = "cuda"
-
-
 class PackedCausalTransformerGenerator:
     def __init__(
         self,
@@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator:
 
 def load_consolidated_model_and_tokenizer(
     consolidated_path,
-    model_cls=LMTransformer,
-    model_args_cls=LMTransformerArgs,
 ):
-    ckpt_path = Path(consolidated_path)
-    config = ckpt_path / "params.json"
-    config = OmegaConf.load(config)
+    train_args_path = os.path.join(consolidated_path, "params.json")
+    fs = get_fs(train_args_path)
+    with fs.open(train_args_path) as f:
+        train_args = TrainArgs.model_validate_json(f.read())
 
+    if train_args.train_entropy_model:
+        model_args = train_args.entropy_model
+        model = LMTransformer(model_args)
+    else:
+        model_args = train_args.model
+        model = ByteLatentTransformer(model_args)
+
     param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        config.distributed.model_dtype
+        train_args.distributed.model_dtype
     ]
-    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
-    tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
-    model = model_cls(model_args)
-    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
+    tokenizer = train_args.data.tokenizer_args.build()
+    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
         param.data = param.data.to(dtype=param_dtype)
-    return model, tokenizer, config
+    return model, tokenizer, train_args
 
 
 def main():
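Downstream, the eval path consumes this exactly as in the hunk above; a condensed usage sketch, where the checkpoint directory is a placeholder:

from bytelatent.generate import (
    PackedCausalTransformerGenerator,
    PackedCausalTransformerGeneratorArgs,
    load_consolidated_model_and_tokenizer,
)

consolidated_dir = "/tmp/checkpoints/0000001000/consolidated"  # placeholder path
model, tokenizer, train_args = load_consolidated_model_and_tokenizer(consolidated_dir)
model.eval()
generator = PackedCausalTransformerGenerator(
    PackedCausalTransformerGeneratorArgs(max_gen_len=64), model, tokenizer
)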
@@ -10,7 +10,7 @@ from copy import deepcopy
 from dataclasses import asdict, dataclass
 from pathlib import Path
 from timeit import default_timer as timer
-from typing import Any, Dict, Type, TypeVar
+from typing import Any, TypeVar
 
 import torch
 import torch.distributed
@@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import lr_scheduler
 
-from bytelatent.args import TrainArgs
+from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
-from bytelatent.data.data_types import DataLoaderState
+from bytelatent.data.iterators.multiprocess_iterator import (
+    MultiprocessIterator,
+    MultiprocessIteratorState,
+)
+from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
@@ -39,6 +43,7 @@ from bytelatent.distributed import (
     setup_env,
     setup_torch_distributed,
 )
+from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
 from bytelatent.logger import init_logger
 from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
 from bytelatent.model.blt import ByteLatentTransformer
@@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"):
     return dict(items)
 
 
-def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
-    """
-    Converts a dictionary to a dataclass instance, recursively for nested structures.
-    """
-    base = OmegaConf.structured(cls())
-    OmegaConf.set_struct(base, strict)
-    override = OmegaConf.create(data)
-    return OmegaConf.to_object(OmegaConf.merge(base, override))
+def get_iterator_state_name(iterator_state):
+    if isinstance(iterator_state, MultiprocessIteratorState):
+        return "multiprocess"
+    elif isinstance(iterator_state, PackingIteratorState):
+        return "packing"
+    else:
+        raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
 
 
+# TODO: Make this pydantic based instead of data class based
+# TODO: Generalize this to any iterator state
 @dataclass
 class TrainState(Stateful):
     step: int  # Nb of steps taken by the optimizer
     acc_step: int  # Nb of accumulation steps done since last optimizer step
     scheduler: lr_scheduler.LambdaLR
-    data_loader_state: DataLoaderState
+    data_loader_state: MultiprocessIteratorState | PackingIteratorState
     scale: float = 1.0
+    data_loader_class: str | None = None
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         return {
             "step": self.step,
             "acc_step": self.acc_step,
-            "data_loader_state": self.data_loader_state.dict(),
+            "data_loader_state": self.data_loader_state.model_dump(),
+            "data_loader_class": get_iterator_state_name(self.data_loader_state),
             "scheduler": self.scheduler.state_dict(),
         }
 
     def load_state_dict(self, state_dict):
         self.step = state_dict["step"]
         self.acc_step = state_dict["acc_step"]
-        self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"])
+        self.data_loader_class = state_dict["data_loader_class"]
+        if self.data_loader_class == "multiprocess":
+            self.data_loader_state = MultiprocessIteratorState(
+                **state_dict["data_loader_state"]
+            )
+        elif self.data_loader_class == "packing":
+            self.data_loader_state = PackingIteratorState(
+                **state_dict["data_loader_state"]
+            )
+        else:
+            raise ValueError(f"invalid data loader class: {self.data_loader_class}")
         self.scheduler.load_state_dict(state_dict["scheduler"])
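The reload side mirrors the save side: the serialized pydantic state is tagged with its class name and re-instantiated from that tag. A self-contained sketch of the round trip, using a stand-in state class (the real ones are MultiprocessIteratorState and PackingIteratorState):

from pydantic import BaseModel


class DemoIteratorState(BaseModel):  # stand-in; the real states carry much more
    position: int = 0


# Save: what TrainState.state_dict stores for the data loader.
saved = {
    "data_loader_state": DemoIteratorState(position=7).model_dump(),
    "data_loader_class": "demo",
}

# Load: dispatch on the tag to rebuild the right pydantic class,
# mirroring TrainState.load_state_dict above.
state_classes = {"demo": DemoIteratorState}
restored = state_classes[saved["data_loader_class"]](**saved["data_loader_state"])
assert restored.position == 7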
@@ -345,7 +363,10 @@ def train(args: TrainArgs):
     nwords_since_last_log = 0
     time_last_log = timer()
     gc.collect()
-    while train_state.step < args.steps:
+    saved = False
+    while train_state.step < args.steps and (
+        args.max_steps is None or train_state.step < args.max_steps
+    ):
         # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
         train_state.acc_step += 1
         train_state.acc_step = train_state.acc_step % args.grad_acc_steps
@@ -552,7 +573,6 @@ def train(args: TrainArgs):
                 f" pow: {gpu_mem_stats.power_draw/1000} W"
             )
 
-        saved = False
         if every_n_steps(
             train_state, args.checkpoint.dump.every, acc_step=0
         ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
@@ -567,18 +587,14 @@ def train(args: TrainArgs):
         if args.eval is not None and every_n_steps(
             train_state, args.checkpoint.eval.every, acc_step=0
         ):
-            from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
-
-            eval_args = dataclass_from_dict(EvalArgs, args.eval)
+            eval_args = args.eval
 
             eval_args.global_step = train_state.step
             eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
-            eval_args.dump_dir = str(
-                os.path.join(
-                    args.dump_dir,
-                    "evals",
-                    EVAL_FOLDER_NAME.format(train_state.step),
-                )
+            eval_args.dump_dir = os.path.join(
+                args.dump_dir,
+                "evals",
+                EVAL_FOLDER_NAME.format(train_state.step),
             )
             eval_args.metric_log_dir = args.dump_dir
             if args.async_eval_gpus is None:
@@ -619,6 +635,9 @@ def train(args: TrainArgs):
             args,
             device_mesh=world_mesh,
         )
+    if isinstance(data_loader, MultiprocessIterator):
+        logger.info("Closing MP iterator before exiting")
+        data_loader.shutdown()
     gc.collect()
@@ -661,15 +680,7 @@ def main():
 
     Plus all the default values in TrainArgs dataclass.
     """
-    cli_args = OmegaConf.from_cli()
-    file_cfg = OmegaConf.load(cli_args.config)
-    # We remove 'config' attribute from config as the underlying DataClass does not have it
-    del cli_args.config
-
-    default_cfg = OmegaConf.create(TrainArgs().model_dump())
-    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
-    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
-    train_args = TrainArgs.model_validate(cfg)
+    train_args = parse_args(TrainArgs)
     if train_args.debug_dynamo:
         import torch._dynamo