mirror of https://github.com/facebookresearch/blt.git
synced 2025-01-31 01:52:15 +00:00

This includes fixes that make checkpointing and reloading work correctly. (#35)

It also batches in a first set of changes for fixing the eval code.

Summary:

Test Plan:

This commit is contained in:
parent 7622d28b74
commit 7044771a12
@@ -544,7 +544,7 @@ def train(args: TrainArgs):
             if args.eval is not None and every_n_steps(
                 train_state, args.checkpoint.eval.every, acc_step=0
             ):
-                from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
+                from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval

                 eval_args = dataclass_from_dict(EvalArgs, args.eval)
@@ -5,6 +5,7 @@ from typing import Any

 import numpy as np
 import yaml
+from omegaconf import OmegaConf
 from pydantic import BaseModel, ConfigDict

 from bytelatent.checkpoint import CheckpointArgs
@@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
     return np.random.default_rng((seed, rank, world_size)).bit_generator.state


+def parse_args(args_cls):
+    cli_args = OmegaConf.from_cli()
+    file_cfg = OmegaConf.load(cli_args.config)
+    # We remove 'config' attribute from config as the underlying DataClass does not have it
+    del cli_args.config
+
+    default_cfg = OmegaConf.create(args_cls().model_dump())
+    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
+    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
+    pydantic_args = args_cls.model_validate(cfg)
+    return pydantic_args
+
+
 def distribute_data_to_rank(
     *,
     dataset_path: str,
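For context, the new parse_args helper layers three sources in a fixed order (class defaults, then the YAML file named by config=, then CLI dot-list overrides) before validating with pydantic. A minimal sketch of that merge order, not from this commit, using a hypothetical SmallArgs model in place of TrainArgs/EvalArgs:

from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict

class SmallArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    name: str = "debug"
    steps: int = 1000

default_cfg = OmegaConf.create(SmallArgs().model_dump())  # 1) class defaults
file_cfg = OmegaConf.create({"steps": 2000})               # 2) stands in for OmegaConf.load(cli_args.config)
cli_cfg = OmegaConf.from_dotlist(["name=quick_run"])       # 3) stands in for OmegaConf.from_cli()
merged = OmegaConf.to_container(OmegaConf.merge(default_cfg, file_cfg, cli_cfg), resolve=True)
print(SmallArgs.model_validate(merged))                    # name='quick_run' steps=2000

Because the models set extra="forbid", a misspelled key in the YAML file or on the command line fails validation instead of being silently dropped.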
@@ -71,6 +85,22 @@ def distribute_data_to_rank(
     return rank_to_arrow_iterator_params[rank]


+class PackedCausalTransformerGeneratorArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    temperature: float = 0.0
+    top_p: float | None = None
+    top_k: float | None = None
+    max_gen_len: int = 512  # Maximum number of tokens to generate
+    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
+    max_prompt_len: int | None = None
+    until: list[str] = []
+    compile_prefilling: bool = False
+    reduce_generation_overhead: bool = False
+    show_progress: bool = False
+    dtype: str | None = "bf16"
+    device: str | None = "cuda"
+
+
 class DataloaderArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     s3_profile: str | None = None
@@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel):
         return packing_iterator


+class LMHarnessArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    tasks: list[Any] | None = None
+    num_fewshot: int | None = None
+    device: str | None = None
+    use_cache: str | None = None
+    cache_requests: bool = False
+    rewrite_requests_cache: bool = False
+    delete_requests_cache: bool = False
+    limit: int | float | None = None
+    bootstrap_iters: int = 100000
+    check_integrity: bool = False
+    write_out: bool = False
+    log_samples: bool = True
+    system_instruction: str | None = None
+    apply_chat_template: bool | str = False
+    fewshot_as_multiturn: bool = False
+    gen_kwargs: str | None = None
+    verbosity: str = "INFO"
+    predict_only: bool = False
+    random_seed: int = 0
+    numpy_random_seed: int = 1234
+    torch_random_seed: int = 1234
+    fewshot_random_seed: int = 1234
+
+
+class ValidationArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    max_steps: int | None = (
+        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
+    )
+    use_val_from_train_src: bool = True  # Use the validation set from training sources
+    root_dir: str = ""
+    sources: list[str] = []  # Other sources to eval on
+
+
+class EvalArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    dump_dir: str
+    ckpt_dir: str
+    metric_log_dir: str | None = None
+    generator: PackedCausalTransformerGeneratorArgs = (
+        PackedCausalTransformerGeneratorArgs()
+    )
+
+    harness: LMHarnessArgs | None = LMHarnessArgs()
+    validation: ValidationArgs | None = ValidationArgs()
+
+    global_step: int | None = None  # for in-training evaluation
+    s3_profile: str | None = None
+
+
 class TrainArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     name: str = "lingua"
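Since dump_dir and ckpt_dir have no defaults, EvalArgs now has to be constructed with both. A small usage sketch, not from this commit, with hypothetical paths:

args = EvalArgs(dump_dir="/tmp/evals/0000001000", ckpt_dir="/tmp/ckpts/0000001000")
print(args.harness.num_fewshot)   # None: nested configs get their declared defaults
blob = args.model_dump_json()     # the serialized form launch_eval writes out in this commit
assert EvalArgs.model_validate_json(blob) == args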
@@ -186,6 +268,9 @@ class TrainArgs(BaseModel):

     # Nb optimizer steps to take
     steps: int = 1000
+    # If not None, halt training after this many steps,
+    # useful for debugging
+    max_steps: int | None = None

     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
@@ -203,7 +288,7 @@ class TrainArgs(BaseModel):

     # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
     async_eval_gpus: int | None = None
-    eval: Any | None = None
+    eval: EvalArgs | None = None
     eval_on_gpus: int | None = None

     def dump_to_yaml_file(
@@ -7,6 +7,7 @@ import re
 from pathlib import Path
 from typing import List, Optional, Tuple

+import fsspec
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
@@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import (
     set_state_dict,
 )

+from bytelatent.data.file_util import get_fs
 from bytelatent.distributed import get_is_master

 logger = logging.getLogger("CHECKPOINT")
@@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel):
     path: str | None = None
     init_ckpt_path: str | None = None
     continue_training_from_init: bool = False
+    s3_profile: str | None = None


 def _get_key_step(name: str):
     return int(re.findall(RE_DIGITS, name)[-1])


-def consolidate_checkpoints(ckpt_dir: str):
+def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
     """
     Consolidates all FSDP checkpoints in a directory to a single file
     Consolidate checkpoint is saved in a subdirectory of ckpt_dir
@@ -102,15 +105,17 @@ def load_from_checkpoint(
     dcp.load(state_dict, checkpoint_id=ckpt_dir)


+# TODO: Rewrite the file operations here to use fsspec to enable s3 writing.
 class CheckpointManager:
     def __init__(self, args: CheckpointArgs):
         self.path = args.path
+        self.fs = get_fs(self.path, s3_profile=args.s3_profile)
         self.dump_every = args.dump
         self.eval_every = args.eval
         self.init_ckpt_path = args.init_ckpt_path
         self.continue_training_from_init = args.continue_training_from_init

-        assert os.path.exists(
+        assert self.fs.exists(
             self.path
         ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
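CheckpointManager now goes through an fsspec filesystem returned by get_fs, so the same code path can target local disk or S3. The helper's body is not shown in this commit; a rough sketch, not from the repo, of what such a resolver typically looks like (the s3:// scheme check and the profile keyword are assumptions):

import fsspec

def get_fs_sketch(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    # Choose the backend from the path scheme: s3://... -> s3fs, anything else -> local files.
    if path.startswith("s3://"):
        return fsspec.filesystem("s3", profile=s3_profile)
    return fsspec.filesystem("file")

fs = get_fs_sketch("/tmp/checkpoints")
fs.makedirs("/tmp/checkpoints/0000001000", exist_ok=True)
print(fs.exists("/tmp/checkpoints/0000001000"))  # True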
@@ -98,11 +98,4 @@ logging:
   freq: 10

 eval_on_gpus: 8
-eval:
-  dataset_dir: /checkpoint/amaia/codegen/datasets/eval
-  tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
-  generator:
-    max_tokens: 65536
-    dtype: bf16
-
-  mp_size: 1
+eval: null
@@ -72,11 +72,4 @@ logging:
   freq: 10

 eval_on_gpus: 8
-eval:
-  dataset_dir: ???
-  tasks: ???
-  generator:
-    max_tokens: 65536
-    dtype: bf16
-
-  mp_size: 1
+eval: null
@@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel):
     n_views: int = 2


-class DataLoaderState(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    multi_choice_state: MultiChoiceState
-    pack_tokens_state: BltPackTokensState
-    prefetch_state: PrefetchState
-
-
-BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
-
-
 class BltSequence(BaseModel):
     tokens: list[int]
     mask: list[bool]
@@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator):
         self.producer = None
         self.stop_iterating_event = None
         self.state_dumped_event = None
+        self.force_shutdown = False
+
+    def shutdown(self):
+        if self.producer is not None:
+            # This properly shuts things down
+            self.producer.kill()
+            self.force_shutdown = True

     def get_state(self) -> MultiprocessIteratorState:
         """
@@ -135,6 +142,10 @@ class MultiprocessIterator(StatefulIterator):
         to halt the background process and allow it to write the state to the main loop
         in order to not lose data
         """
+        if self.force_shutdown:
+            raise ValueError(
+                "State will be invalid if shutdown was forced before state persisted."
+            )
         if self.producer is None:
             serialized_prefetch_buffer = json.dumps(
                 [b.to_python_dict() for b in self.prefetch_buffer]
@@ -187,6 +198,10 @@ class MultiprocessIterator(StatefulIterator):
             )

     def create_iter(self):
+        if self.force_shutdown:
+            raise ValueError(
+                "Iterator may be invalid if shutdown was forced before state persisted."
+            )
         logging.info("Main thread: Creating MP iterator")
         # First yield from the stored prefetch buffer.
         if self.prefetch_buffer is not None:
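Taken together, these three hunks make the shutdown path explicit: kill() the producer, record force_shutdown, and refuse to hand out state or a new iterator afterwards. A short sketch, not from this commit, of the ordering a caller is expected to follow, assuming an already constructed MultiprocessIterator named data_loader:

state = data_loader.get_state()   # persist first; this drains the prefetch buffer safely
data_loader.shutdown()            # then kill the producer process
# any later data_loader.get_state() or data_loader.create_iter() now raises ValueError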
@@ -4,20 +4,20 @@ import json
 import logging
 import os
 from collections import defaultdict
-from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from pathlib import Path
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any

 import torch
-from lingua.args import dump_config
-from lingua.data import init_choice_state, setup_sources
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from omegaconf import OmegaConf
+from pydantic import BaseModel, ConfigDict

+from bytelatent.args import EvalArgs, ValidationArgs, parse_args
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
+from bytelatent.data.file_util import get_fs
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
@@ -25,72 +25,17 @@ from bytelatent.distributed import (
     get_world_size,
     setup_torch_distributed,
 )
-from bytelatent.transformer import LMTransformer, LMTransformerArgs
-
-from apps.main.generate import (
+from bytelatent.generate import (
     PackedCausalTransformerGenerator,
-    PackedCausalTransformerGeneratorArgs,
     load_consolidated_model_and_tokenizer,
 )
+from bytelatent.transformer import LMTransformer, LMTransformerArgs

 EVAL_FOLDER_NAME = "{:010d}"

 logger = logging.getLogger()


-@dataclass
-class LMHarnessArgs:
-    tasks: Optional[List[Any]] = None
-    num_fewshot: Optional[int] = None
-    device: Optional[str] = None
-    use_cache: Optional[str] = None
-    cache_requests: bool = False
-    rewrite_requests_cache: bool = False
-    delete_requests_cache: bool = False
-    limit: Optional[Union[int, float]] = None
-    bootstrap_iters: int = 100000
-    check_integrity: bool = False
-    write_out: bool = False
-    log_samples: bool = True
-    system_instruction: Optional[str] = None
-    apply_chat_template: Union[bool, str] = False
-    fewshot_as_multiturn: bool = False
-    gen_kwargs: Optional[str] = None
-    verbosity: str = "INFO"
-    predict_only: bool = False
-    random_seed: int = 0
-    numpy_random_seed: int = 1234
-    torch_random_seed: int = 1234
-    fewshot_random_seed: int = 1234
-
-
-@dataclass
-class ValidationArgs:
-    max_steps: Optional[int] = (
-        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
-    )
-    use_val_from_train_src: bool = True  # Use the validation set from training sources
-    root_dir: str = ""
-    sources: List[str] = field(default_factory=list)  # Other sources to eval on
-
-
-@dataclass
-class EvalArgs:
-    name: str = "evals"
-    dump_dir: Optional[str] = None
-    metric_log_dir: Optional[str] = None
-    ckpt_dir: str = ""
-    generator: PackedCausalTransformerGeneratorArgs = field(
-        default_factory=PackedCausalTransformerGeneratorArgs
-    )
-    harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
-    validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)
-
-    wandb: Optional[Any] = None
-
-    global_step: Optional[int] = None  # for in-training evaluation
-
-
 def all_dicts_same(dict_list):
     if not dict_list:  # Check if the list is empty
         return True
@@ -120,7 +65,7 @@ class EvalHarnessLM(LM):
         self._world_size = get_world_size()
         self.device = generator.device

-    def generate_until(self, requests: List[Instance]) -> List[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         prompts, gen_args = zip(*[req.args for req in requests])
         assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
         gen_args = gen_args[0]
@@ -141,7 +86,7 @@ class EvalHarnessLM(LM):
                 filtered_gen.append(g)
         return filtered_gen

-    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         prompts, continuations = zip(*[req.args for req in requests])
         inputs = [req.args[0] + req.args[1] for req in requests]
         max_gen_len = self.generator.max_gen_len
@@ -158,7 +103,7 @@ class EvalHarnessLM(LM):
         self.generator.max_gen_len = max_gen_len
         return results

-    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         prompts = [req.args[0] for req in requests]
         max_gen_len = self.generator.max_gen_len
         # We temporarily lower max gen len
@@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
     return all_val_metrics


-def launch_eval(cfg: EvalArgs):
+def launch_eval(eval_args: EvalArgs):
     if not torch.distributed.is_initialized():
         setup_torch_distributed(DistributedArgs())

+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
     if (
-        Path(cfg.ckpt_dir).exists()
-        and (Path(cfg.ckpt_dir) / "params.json").exists()
-        and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
     ):
-        consolidate_path = Path(cfg.ckpt_dir)
+        consolidate_path = eval_args.ckpt_dir
     else:
-        consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
-        if not consolidate_path.exists() and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)
+        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
+        if not fs.exists(consolidate_path) and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)

-    Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
-    dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)
+    fs.mkdirs(eval_args.dump_dir, exist_ok=True)
+    with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
+        f.write(eval_args.model_dump_json())

-    consolidate_path = str(consolidate_path)
     torch.distributed.barrier()
     logger.info("Loading model")
+    # TODO: Make this general so that it works with either
+    # LMTransformer or Blt, similar with args
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
-        model_cls=LMTransformer,
-        model_args_cls=LMTransformerArgs,
     )
     logger.info("Model loaded")
     model.eval()
-    generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)
+    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)

     wrap = EvalHarnessLM(generator)
-    results = simple_evaluate(wrap, **asdict(cfg.harness))
+    # Redo
+    results = simple_evaluate(wrap, eval_args.harness.model_dump())
     val_results = None
-    if cfg.validation:
-        val_results = eval_on_val(generator, cfg.validation, train_cfg)
+    if eval_args.validation:
+        val_results = eval_on_val(generator, eval_args.validation, train_cfg)
     if get_global_rank() == 0:
-        with open(Path(cfg.dump_dir) / "results.json", "w") as f:
+        with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
         logger.info(f"All evaluation results: {results['results']}")
         if val_results is not None:
-            with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
+            with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
                 f.write(json.dumps(val_results))
             logger.info(f"All validation results: {val_results}")
-    if cfg.metric_log_dir and get_global_rank() == 0:
-        metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"
+    if eval_args.metric_log_dir and get_global_rank() == 0:
+        metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")

         logger.info(f"Writing metric logs to {metric_log_path}")
         timestamp = {
             "created_at": datetime.utcnow().isoformat(),
         }
-        if cfg.global_step is not None:
-            timestamp["global_step"] = cfg.global_step
+        if eval_args.global_step is not None:
+            timestamp["global_step"] = eval_args.global_step
         print(
             json.dumps(timestamp | results["results"]),
-            file=open(metric_log_path, mode="a"),
+            file=fs.open(metric_log_path, mode="a"),
             flush=True,
         )

-        val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
+        val_log_path = os.path.join(
+            eval_args.metric_log_dir, "metrics.validation.jsonl"
+        )
         if val_results is not None:
             print(
                 json.dumps(timestamp | val_results),
-                file=open(val_log_path, mode="a"),
+                file=fs.open(val_log_path, mode="a"),
                 flush=True,
             )
@@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs):


 def main():
-    """
-    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
-    This accepts arguments as a dot list
-    So if the dataclass looks like
-
-    @dataclass
-    class DummyArgs:
-        name: str
-        model: LMTransformerArgsgs
-
-    @dataclass
-    class LMTransformerArgsgs:
-        dim: int
-
-    Then you can pass model.dim=32 to change values in LMTransformerArgsgs
-    or just name=tictac for top level attributes.
-
-    The behavior here is as follows:
-    1. We instantiate EvalArgs with its default values
-    2. We override those default values with the ones in the provided config file
-    3. We override the result with the additional arguments provided through command line
-
-    For example, if the config is the following
-
-    model:
-        dim: 128
-        n_layers: 4
-
-    and you call eval.py with eval.py model.dim=64
-
-    Then the final TrainArgs will have
-
-    model:
-        dim: 64
-        n_layers: 4
-
-    Plus all the default values in EvalArgs dataclass.
-    """
-    cli_args = OmegaConf.from_cli()
-    file_cfg = OmegaConf.load(cli_args.config)
-    # We remove 'config' attribute from config as the underlying DataClass does not have it
-    del cli_args.config
-
-    default_cfg = OmegaConf.structured(EvalArgs())
-    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
-    cfg = OmegaConf.to_object(cfg)
-    launch_eval(cfg)
+    eval_args = parse_args(EvalArgs)
+    launch_eval(eval_args)


 if __name__ == "__main__":
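With this change, running the eval entry point is the same two steps as training: parse_args(EvalArgs) followed by launch_eval. A hypothetical invocation would look like python -m bytelatent.eval config=eval.yaml ckpt_dir=/tmp/ckpts/0000001000 dump_dir=/tmp/evals, where the exact module path is an assumption; config is consumed by parse_args and the remaining dot-list keys override EvalArgs fields.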
@@ -1,20 +1,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.

+import os
 import time
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import List, Optional

 import torch
-from lingua.args import dataclass_from_dict
-from lingua.tokenizers.abstract_tokenizer import Tokenizer
-from lingua.tokenizers.build_tokenizer import build_tokenizer
 from omegaconf import OmegaConf
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import create_block_mask
 from tqdm import tqdm

+from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
     Attention,
     causal_mask,
@@ -23,7 +19,10 @@ from bytelatent.base_transformer import (
     lengths_to_start_ids,
 )
 from bytelatent.checkpoint import CONSOLIDATE_NAME
-from bytelatent.transformer import LMTransformer, LMTransformerArgs
+from bytelatent.data.file_util import get_fs
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
+from bytelatent.transformer import LMTransformer


 def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
@@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
     return next_token.view(shape[:-1])


-def pack_prompts(prompts: List[int]):
+def pack_prompts(prompts: list[int]):
     res = []
     lengths = []
     for i, p in enumerate(prompts):
@@ -120,22 +119,6 @@ class KVCache(nn.Module):
         return self.k_cache, self.v_cache


-@dataclass
-class PackedCausalTransformerGeneratorArgs:
-    temperature: float = 0.0
-    top_p: Optional[float] = None
-    top_k: Optional[float] = None
-    max_gen_len: int = 512  # Maximum number of tokens to generate
-    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
-    max_prompt_len: Optional[int] = None
-    until: List[str] = field(default_factory=list)
-    compile_prefilling: bool = False
-    reduce_generation_overhead: bool = False
-    show_progress: bool = False
-    dtype: Optional[str] = "bf16"
-    device: Optional[str] = "cuda"
-
-
 class PackedCausalTransformerGenerator:
     def __init__(
         self,
@@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator:

 def load_consolidated_model_and_tokenizer(
     consolidated_path,
-    model_cls=LMTransformer,
-    model_args_cls=LMTransformerArgs,
 ):
-    ckpt_path = Path(consolidated_path)
-    config = ckpt_path / "params.json"
-    config = OmegaConf.load(config)
+    train_args_path = os.path.join(consolidated_path, "params.json")
+    fs = get_fs(train_args_path)
+    with fs.open(train_args_path) as f:
+        train_args = TrainArgs.model_validate_json(f.read())
+
+    if train_args.train_entropy_model:
+        model_args = train_args.entropy_model
+        model = LMTransformer(model_args)
+    else:
+        model_args = train_args.model
+        model = ByteLatentTransformer(model_args)
+
     param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        config.distributed.model_dtype
+        train_args.distributed.model_dtype
     ]
-    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
-    tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
-    model = model_cls(model_args)
-    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
+    tokenizer = train_args.data.tokenizer_args.build()
+    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
         param.data = param.data.to(dtype=param_dtype)
-    return model, tokenizer, config
+    return model, tokenizer, train_args


 def main():
@@ -10,7 +10,7 @@ from copy import deepcopy
 from dataclasses import asdict, dataclass
 from pathlib import Path
 from timeit import default_timer as timer
-from typing import Any, Dict, Type, TypeVar
+from typing import Any, TypeVar

 import torch
 import torch.distributed
@@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import lr_scheduler

-from bytelatent.args import TrainArgs
+from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
-from bytelatent.data.data_types import DataLoaderState
+from bytelatent.data.iterators.multiprocess_iterator import (
+    MultiprocessIterator,
+    MultiprocessIteratorState,
+)
+from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
@@ -39,6 +43,7 @@ from bytelatent.distributed import (
     setup_env,
     setup_torch_distributed,
 )
+from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
 from bytelatent.logger import init_logger
 from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
 from bytelatent.model.blt import ByteLatentTransformer
@@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"):
     return dict(items)


-def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
-    """
-    Converts a dictionary to a dataclass instance, recursively for nested structures.
-    """
-    base = OmegaConf.structured(cls())
-    OmegaConf.set_struct(base, strict)
-    override = OmegaConf.create(data)
-    return OmegaConf.to_object(OmegaConf.merge(base, override))
+def get_iterator_state_name(iterator_state):
+    if isinstance(iterator_state, MultiprocessIteratorState):
+        return "multiprocess"
+    elif isinstance(iterator_state, PackingIteratorState):
+        return "packing"
+    else:
+        raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")


+# TODO: Make this pydantic based instead of data class based
+# TODO: Generalize this to any iterator state
 @dataclass
 class TrainState(Stateful):
     step: int  # Nb of steps taken by the optimizer
     acc_step: int  # Nb of accumulation steps done since last optimizer step
     scheduler: lr_scheduler.LambdaLR
-    data_loader_state: DataLoaderState
+    data_loader_state: MultiprocessIteratorState | PackingIteratorState
     scale: float = 1.0
+    data_loader_class: str | None = None

-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         return {
             "step": self.step,
             "acc_step": self.acc_step,
-            "data_loader_state": self.data_loader_state.dict(),
+            "data_loader_state": self.data_loader_state.model_dump(),
+            "data_loader_class": get_iterator_state_name(self.data_loader_state),
             "scheduler": self.scheduler.state_dict(),
         }

     def load_state_dict(self, state_dict):
         self.step = state_dict["step"]
         self.acc_step = state_dict["acc_step"]
-        self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"])
+        self.data_loader_class = state_dict["data_loader_class"]
+        if self.data_loader_class == "multiprocess":
+            self.data_loader_state = MultiprocessIteratorState(
+                **state_dict["data_loader_state"]
+            )
+        elif self.data_loader_class == "packing":
+            self.data_loader_state = PackingIteratorState(
+                **state_dict["data_loader_state"]
+            )
+        else:
+            raise ValueError(f"invalid data loader class: {self.data_loader_class}")
         self.scheduler.load_state_dict(state_dict["scheduler"])
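The checkpoint payload is now self-describing: state_dict() tags the serialized loader state with a name, and load_state_dict() uses that tag to pick which pydantic class to rebuild. A hedged illustration of the same dispatch, not from this commit, assuming some existing loader_state object:

name = get_iterator_state_name(loader_state)   # "multiprocess" or "packing"
state_cls = {"multiprocess": MultiprocessIteratorState, "packing": PackingIteratorState}[name]
restored = state_cls(**loader_state.model_dump())   # mirrors what load_state_dict does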
@@ -345,7 +363,10 @@ def train(args: TrainArgs):
     nwords_since_last_log = 0
     time_last_log = timer()
     gc.collect()
-    while train_state.step < args.steps:
+    saved = False
+    while train_state.step < args.steps and (
+        args.max_steps is None or train_state.step < args.max_steps
+    ):
         # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
         train_state.acc_step += 1
         train_state.acc_step = train_state.acc_step % args.grad_acc_steps
@@ -552,7 +573,6 @@ def train(args: TrainArgs):
                 f" pow: {gpu_mem_stats.power_draw/1000} W"
             )

-        saved = False
         if every_n_steps(
             train_state, args.checkpoint.dump.every, acc_step=0
         ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
@@ -567,18 +587,14 @@ def train(args: TrainArgs):
             if args.eval is not None and every_n_steps(
                 train_state, args.checkpoint.eval.every, acc_step=0
             ):
-                from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
-
-                eval_args = dataclass_from_dict(EvalArgs, args.eval)
+                eval_args = args.eval

                 eval_args.global_step = train_state.step
                 eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
-                eval_args.dump_dir = str(
-                    os.path.join(
-                        args.dump_dir,
-                        "evals",
-                        EVAL_FOLDER_NAME.format(train_state.step),
-                    )
+                eval_args.dump_dir = os.path.join(
+                    args.dump_dir,
+                    "evals",
+                    EVAL_FOLDER_NAME.format(train_state.step),
                 )
                 eval_args.metric_log_dir = args.dump_dir
                 if args.async_eval_gpus is None:
@@ -619,6 +635,9 @@ def train(args: TrainArgs):
         args,
         device_mesh=world_mesh,
     )
+    if isinstance(data_loader, MultiprocessIterator):
+        logger.info("Closing MP iterator before exiting")
+        data_loader.shutdown()
     gc.collect()
@@ -661,15 +680,7 @@ def main():

     Plus all the default values in TrainArgs dataclass.
     """
-    cli_args = OmegaConf.from_cli()
-    file_cfg = OmegaConf.load(cli_args.config)
-    # We remove 'config' attribute from config as the underlying DataClass does not have it
-    del cli_args.config
-
-    default_cfg = OmegaConf.create(TrainArgs().model_dump())
-    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
-    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
-    train_args = TrainArgs.model_validate(cfg)
+    train_args = parse_args(TrainArgs)

     if train_args.debug_dynamo:
         import torch._dynamo