Mirror of https://github.com/facebookresearch/blt.git (synced 2025-04-21 00:59:09 +00:00)
Get evals working again. (#46)
- PPL/validation: works now and uses multi-GPU. For some reason 1 GPU differs from multi-GPU; can debug in a follow-up PR.
- Generation evals likely work, but are very slow, so they are disabled for now.

Test Plan:

```
torchrun --nproc-per-node 8 -m bytelatent.eval config=../internal-blt/configs/eval.yaml
```
parent 63913e4dba
commit 7517ac2a9f

6 changed files with 276 additions and 101 deletions

@@ -270,6 +270,10 @@ class EvalArgs(BaseModel):
     dump_dir: str | None = None
     ckpt_dir: str | None = None
     metric_log_dir: str | None = None
+
+    run_ppl: bool = True
+    run_tasks: bool = False
+
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
     )
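The two new flags let an eval run toggle the fast perplexity pass and the slow generation-based harness tasks independently. A minimal sketch of setting them from Python, assuming the remaining EvalArgs fields keep their defaults (the actual contents of the eval.yaml referenced in the test plan are not part of this diff, so the values below are placeholders):

```python
# Illustrative only: paths and values are placeholders, not the repo's real config.
from bytelatent.args import EvalArgs

eval_args = EvalArgs(
    dump_dir="/tmp/blt_eval_dump",   # where results.json / validation.json are written
    ckpt_dir="/tmp/blt_checkpoint",  # checkpoint directory to consolidate and evaluate
    run_ppl=True,     # new flag: run the multi-GPU PPL/validation pass
    run_tasks=False,  # new flag: generation evals are slow, so off by default
)
```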
@@ -15,6 +15,7 @@ from functools import lru_cache, partial, reduce
 from itertools import chain
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 
 # for no recompute ops
@@ -78,6 +79,40 @@ class DistributedArgs(BaseModel):
 
     spawn_method: str = "forkserver"
 
+    def configure_world(self):
+        pass
+        if self.dp_replicate * self.dp_shard * self.tp_size != get_world_size():
+            logging.info("Modifying TrainArgs distributed config")
+            assert get_world_size() % self.dp_shard == 0
+            logging.info("World size: %s", get_world_size())
+            logging.info(
+                "Existing setting: train_args.distributed.dp_shard=%s",
+                self.dp_shard,
+            )
+            logging.info(
+                "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
+                get_world_size() // self.dp_shard,
+                self.dp_replicate,
+            )
+            self.dp_replicate = get_world_size() // self.dp_shard
+
+            logging.info(
+                "Changing dp_replicate from %s to %s, to account for tp_size=%s",
+                self.dp_replicate,
+                self.dp_replicate // self.tp_size,
+                self.tp_size,
+            )
+            assert self.dp_replicate % self.tp_size == 0
+            self.dp_replicate = self.dp_replicate // self.tp_size
+
+            logger.warning(
+                f"Setting Data Parallel size to {self.dp_replicate * self.dp_shard}"
+            )
+            assert self.dp_replicate * self.dp_shard * self.tp_size == get_world_size()
+
+        if self.fsdp_type == "no_shard":
+            assert self.dp_shard == 1 and self.dp_replicate == get_world_size()
+
 
 class EnvironmentArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
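For reference, the world-size reconciliation above reduces to a couple of integer divisions. A standalone sketch of the same arithmetic (a hypothetical helper, not code from the repo):

```python
def reconcile_dp_replicate(world_size: int, dp_shard: int, tp_size: int) -> int:
    """Sketch of what configure_world computes when the configured sizes disagree."""
    assert world_size % dp_shard == 0
    dp_replicate = world_size // dp_shard  # spread the remaining ranks across replicas
    assert dp_replicate % tp_size == 0
    dp_replicate //= tp_size  # carve tensor-parallel groups out of the replica dimension
    assert dp_replicate * dp_shard * tp_size == world_size
    return dp_replicate

# 8 GPUs with dp_shard=1 and tp_size=1 (the torchrun command in the test plan) -> 8 replicas
assert reconcile_dp_replicate(8, 1, 1) == 8
```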
@@ -151,6 +186,13 @@ def dist_mean_dict(x):
     return r
 
 
+def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
+    if isinstance(num, (torch.Tensor, np.ndarray)):
+        return num.item()
+    else:
+        return num
+
+
 @lru_cache()
 def get_is_torch_run() -> bool:
     return os.environ.get("LOCAL_RANK") is not None
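to_py_num, which the later train.py hunks move out of train.py and into this module, exists mainly so that metric dictionaries built from torch or numpy scalars can be passed to json.dumps, which rejects tensor and ndarray types. A small usage sketch:

```python
import json

import numpy as np
import torch

from bytelatent.distributed import to_py_num

metrics = {"loss": to_py_num(torch.tensor(2.3)), "n_bytes": to_py_num(np.int64(1024))}
json.dumps(metrics)  # fine; passing the raw torch/numpy scalars would raise TypeError
```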
@@ -2,6 +2,7 @@
 
 import json
 import logging
+import math
 import os
 from collections import defaultdict
 from datetime import datetime
@@ -10,22 +11,48 @@ import torch
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
+from rich.progress import track
+from torch.nn import functional as F
 
-from bytelatent.args import EvalArgs, ValidationArgs
+from bytelatent.args import (
+    EvalArgs,
+    TrainArgs,
+    ValidationArgs,
+    find_and_sanitize_chunks,
+)
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+)
+from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
+from bytelatent.data.iterators.sequence_iterator import (
+    SequenceIterator,
+    SequencePackingArgs,
+)
+from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
+    dist_sum,
+    get_device_mesh,
     get_global_rank,
     get_world_size,
     setup_torch_distributed,
+    to_py_num,
 )
 from bytelatent.generate import (
     PackedCausalTransformerGenerator,
     load_consolidated_model_and_tokenizer,
 )
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+from bytelatent.transformer import LMTransformer
 
 EVAL_FOLDER_NAME = "{:010d}"
 
@@ -113,19 +140,134 @@ class EvalHarnessLM(LM):
         return results
 
 
-def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
-    srcs = {}
+@torch.no_grad()
+def eval_ppl_on_path(
+    *,
+    world_rank: int,
+    world_size: int,
+    model: LMTransformer | ByteLatentTransformer,
+    tokenizer_args: TokenizerArgs,
+    patcher_args: PatcherArgs,
+    add_patches: bool,
+    path: str,
+    batch_size: int,
+    arrow_batch_size: int,
+    max_n_docs: int | None,
+    s3_profile: str | None = None,
+):
+    model.eval()
+    tokenizer = tokenizer_args.build()
+    seq_len = model.get_output_seq_len()
+    chunks = find_and_sanitize_chunks(
+        path,
+        world_size=1,
+        file_pattern="*.val.jsonl",
+        s3_profile=s3_profile,
+    )
+    assert (
+        len(chunks) == 1
+    ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+    chunk = chunks[0]
+    arrow_iterator = ArrowFileIterator(
+        file_path=chunk,
+        preprocess_dir=None,
+        entropy_model_name=None,
+        worker_id=world_rank,
+        num_workers=world_size,
+        arrow_batch_size=arrow_batch_size,
+        s3_profile=s3_profile,
+        file_format="json",
+    )
+    if max_n_docs is not None:
+        arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
+    preprocess_iterator = PreprocessIterator(
+        arrow_iterator,
+        patcher_args=patcher_args,
+        tokenizer_args=tokenizer_args,
+        add_patches=add_patches,
+    )
+    sequence_iterator = SequenceIterator(
+        preprocess_iterator,
+        sequence_packing_args=SequencePackingArgs(
+            output_seq_len=seq_len,
+            # Effectively disables shuffles
+            buffer_size=1,
+        ),
+        rng_state=None,
+    )
+    packing_args = PackingArgs(
+        batch_size=batch_size,
+        seq_len=seq_len,
+        # TODO: make these seq lens worth with blt
+        max_length=seq_len,
+        pad_to_max_length=True,
+        enable_byte_ngrams=False,
+        pad_id=tokenizer.boe_id,
+        packing_mode=PackingMode.BYTES,
+    )
+    packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
+    total_loss = 0.0
+    n_bytes = 0
+    batch_iterator = packing_iterator.create_iter()
+    for batch in batch_iterator:
+        x = torch.from_numpy(batch.x).cuda()
+        y = torch.from_numpy(batch.y).cuda()
+        mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+        if tokenizer_args.name in ["bytes", "blt"]:
+            n_bytes += y.numel() if mask is None else mask.sum().item()
+            pred = model(x)
+            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
+            total_loss += loss.item()
+        else:
+            raise NotImplementedError()
+    all_n_bytes = to_py_num(dist_sum(n_bytes))
+    all_total_loss = to_py_num(dist_sum(total_loss))
+    return {
+        "n_bytes": all_n_bytes,
+        "n_bytes_gpu": n_bytes,
+        "loss_sum": all_total_loss,
+        "loss_sum_gpu": total_loss,
+        "loss_mean": all_total_loss / all_n_bytes,
+        "loss_mean_gpu": total_loss / n_bytes,
+        "ppl": math.exp(all_total_loss / all_n_bytes) if all_n_bytes > 0 else 0.0,
+        "bpb": all_total_loss / math.log(2) / all_n_bytes,
+    }
+
+
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
+    srcs = []
     for src in val_args.sources:
         path = os.path.join(val_args.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)
 
     for src in train_cfg.data.sources:
         path = os.path.join(train_cfg.data.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)
 
-    multi_state = init_choice_state(
-        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
-    )
-    path_to_iter = setup_sources(multi_state)
+    path_to_iter = {}
+    for path in srcs:
+        chunks = find_and_sanitize_chunks(
+            path,
+            world_size=1,
+            file_pattern="*.val.jsonl",
+            s3_profile=train_cfg.data.s3_profile,
+        )
+        assert (
+            len(chunks) == 1
+        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+        chunk = chunks[0]
+        iterator = ArrowFileIterator(
+            dataset_files=[chunk],
+            file_path=None,
+            preprocess_dir=None,
+            entropy_model_name=None,
+            worker_id=0,
+            num_workers=1,
+            arrow_batch_size=train_cfg.data.arrow_batch_size,
+            s3_profile=train_cfg.data.s3_profile,
+            file_format="json",
+        )
+        path_to_iter[path] = iterator
+
     max_gen_len = generator.max_gen_len
     # We temporarily lower max gen len
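How the multi-GPU numbers come together: each rank reads a disjoint slice of the validation file (worker_id=world_rank, num_workers=world_size), accumulates a summed cross-entropy and a byte count locally, and the reported metrics are computed from the dist_sum-reduced totals. A sketch of just the final arithmetic, with made-up totals (loss sums are in nats):

```python
import math

# Illustrative totals after dist_sum across ranks; not real measurements.
all_total_loss = 1_250_000.0  # summed cross-entropy over all scored bytes, in nats
all_n_bytes = 1_000_000       # number of target bytes that were scored

loss_mean = all_total_loss / all_n_bytes          # nats per byte
ppl = math.exp(loss_mean)                         # per-byte perplexity
bpb = all_total_loss / math.log(2) / all_n_bytes  # bits per byte
```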
@@ -133,16 +275,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
 
     all_val_metrics = {}
     for src in path_to_iter:
-        jsonl_iterator = path_to_iter[src]
+        example_iterator = path_to_iter[src].create_iter()
         texts = []
         logger.info(f"Running validation on {src}...")
-        for step, (content, state) in enumerate(jsonl_iterator):
-            if state["current_iter"] > 0 or (
-                val_args.max_steps is not None and step >= val_args.max_steps
-            ):
-                break
-            content_key = "text" if ("text" in content) else "content"
-            texts.append(content[content_key])
+        for step, example in enumerate(example_iterator):
+            texts.append(example.text)
 
         _, loglikelihood, _ = generator.generate(texts)
 
@@ -174,8 +311,18 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
 
 
 def launch_eval(eval_args: EvalArgs):
+    assert eval_args.dump_dir is not None
+    assert eval_args.ckpt_dir is not None
+    distributed_args = DistributedArgs()
+    distributed_args.configure_world()
     if not torch.distributed.is_initialized():
-        setup_torch_distributed(DistributedArgs())
+        setup_torch_distributed(distributed_args)
+
+    world_mesh = get_device_mesh(distributed_args)
+    dp_mesh = world_mesh["dp_replicate"]
+    assert distributed_args.dp_shard == 1
+    world_size = dp_mesh.size()
+    world_rank = dp_mesh.get_local_rank()
+
     fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
     if (
@@ -187,7 +334,7 @@ def launch_eval(eval_args: EvalArgs):
     else:
         consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
         if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
 
     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -200,35 +347,67 @@ def launch_eval(eval_args: EvalArgs):
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
     )
-    logger.info("Model loaded")
     model.eval()
-    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
+    logger.info("Model loaded")
 
-    wrap = EvalHarnessLM(generator)
-    # Redo
-    results = simple_evaluate(wrap, eval_args.harness.model_dump())
-    val_results = None
-    if eval_args.validation:
-        val_results = eval_on_val(generator, eval_args.validation, train_cfg)
+    ppl_results = None
+    if eval_args.run_ppl:
+        assert eval_args.validation is not None
+        if len(eval_args.validation.sources) > 0:
+            ppl_results = {}
+            logger.info("Starting PPL evaluation on validation sets")
+            for source in eval_args.validation.sources:
+                ppl_results[source] = eval_ppl_on_path(
+                    world_rank=world_rank,
+                    world_size=world_size,
+                    model=model,
+                    tokenizer_args=train_cfg.data.tokenizer_args,
+                    # TODO: Don't hardcode, modify based on model
+                    patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
+                    add_patches=False,
+                    path=os.path.join(eval_args.validation.root_dir, source),
+                    max_n_docs=eval_args.validation.max_n_docs,
+                    batch_size=8,
+                    arrow_batch_size=100,
+                    s3_profile="blt",
+                )
+
+    task_results = None
+    if eval_args.run_tasks:
+        assert eval_args.generator is not None
+        assert eval_args.harness is not None
+        generator = PackedCausalTransformerGenerator(
+            eval_args.generator, model, tokenizer
+        )
+        wrap = EvalHarnessLM(generator)
+        # TODO: This needs to be checked/sped up
+        task_results = simple_evaluate(wrap, **eval_args.harness.model_dump())
+
+    results = {"ppl": ppl_results, "tasks": task_results}
+    # TODO: Serial and Parallel yield slightly different number of bytes, debug this later,
+    # leaving this log statement here to help with that.
+    # logging.info("Rank: %s Results: %s", world_rank, results)
+
     if get_global_rank() == 0:
         with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
-        logger.info(f"All evaluation results: {results['results']}")
-        if val_results is not None:
+        logger.info(f"All evaluation results: {results}")
+        if ppl_results is not None:
             with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
-                f.write(json.dumps(val_results))
-            logger.info(f"All validation results: {val_results}")
+                f.write(json.dumps(ppl_results))
+            logger.info(f"All validation results: {ppl_results}")
 
     if eval_args.metric_log_dir and get_global_rank() == 0:
         metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
 
         logger.info(f"Writing metric logs to {metric_log_path}")
-        timestamp = {
+        timestamp: dict[str, int | str] = {
             "created_at": datetime.utcnow().isoformat(),
         }
         if eval_args.global_step is not None:
             timestamp["global_step"] = eval_args.global_step
         print(
-            json.dumps(timestamp | results["results"]),
+            json.dumps(timestamp | results),
             file=fs.open(metric_log_path, mode="a"),
             flush=True,
         )
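With this restructuring, a single results.json holds both kinds of evals under one object. Roughly, with illustrative source names and numbers (the PPL entries are a subset of what eval_ppl_on_path returns):

```python
# Sketch of the written results structure; keys under "ppl" come from the validation sources.
results = {
    "ppl": {
        "some_validation_source": {
            "n_bytes": 1_000_000,
            "loss_sum": 1_250_000.0,
            "loss_mean": 1.25,
            "ppl": 3.49,  # exp(1.25)
            "bpb": 1.80,  # 1.25 / ln(2)
        },
    },
    "tasks": None,  # filled by lm_eval's simple_evaluate when run_tasks=True
}
```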
@@ -236,18 +415,16 @@ def launch_eval(eval_args: EvalArgs):
         val_log_path = os.path.join(
             eval_args.metric_log_dir, "metrics.validation.jsonl"
         )
-        if val_results is not None:
+        if ppl_results is not None:
             print(
-                json.dumps(timestamp | val_results),
+                json.dumps(timestamp | ppl_results),
                 file=fs.open(val_log_path, mode="a"),
                 flush=True,
             )
 
-    del generator
-
 
 def main():
-    eval_args = parse_args(EvalArgs)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
     launch_eval(eval_args)
 
 
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
 ):
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    with fs.open(train_args_path) as f:
-        train_args = TrainArgs.model_validate_json(f.read())
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
 
     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
         train_args.distributed.model_dtype
     ]
     tokenizer = train_args.data.tokenizer_args.build()
-    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
+        st_dict = torch.load(f, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
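Both changes above route checkpoint reads through the fsspec filesystem returned by get_fs instead of pathlib, so a consolidated checkpoint can live on S3 as well as on local disk. A condensed sketch of the resulting load path, using the same calls as the diff (CONSOLIDATE_NAME is assumed here to be importable from bytelatent.checkpoint):

```python
import os

import torch

from bytelatent.args import TrainArgs
from bytelatent.checkpoint import CONSOLIDATE_NAME
from bytelatent.data.file_util import get_fs

consolidated_path = "s3://bucket/run/consolidated"  # or a local directory
params_path = os.path.join(consolidated_path, "params.json")
fs = get_fs(params_path)
train_args = TrainArgs.model_validate_json(fs.read_text(params_path))
with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
    st_dict = torch.load(f, weights_only=True)
```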
@@ -55,7 +55,7 @@ class LoggingArgs(BaseModel):
 class MetricLogger:
     def __init__(
         self,
-        outdir: Path,
+        outdir: str,
         # args: TrainArgs
         args: Any | None = None,
         fs: fsspec.AbstractFileSystem | None = None,
@@ -48,6 +48,7 @@ from bytelatent.distributed import (
     requeue_slurm_job,
     setup_env,
     setup_torch_distributed,
+    to_py_num,
 )
 from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
 from bytelatent.logger import init_logger
@@ -91,13 +92,6 @@ def get_iterator_state_name(iterator_state):
     raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
 
 
-def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
-    if isinstance(num, (torch.Tensor, np.ndarray)):
-        return num.item()
-    else:
-        return num
-
-
 # TODO: Make this pydantic based instead of data class based
 # TODO: Generalize this to any iterator state
 @dataclass
@@ -154,57 +148,13 @@ def validate_train_args(args: TrainArgs, output_size: int):
         logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
         args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")
 
-    data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
-    for source in args.data.sources:
-        data_path = os.path.join(args.data.root_dir, source)
-        assert data_fs.exists(data_path), f"{data_path} doesn't exist"
+    if args.data.root_dir is not None:
+        data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
+        for source in args.data.sources:
+            data_path = os.path.join(args.data.root_dir, source)
+            assert data_fs.exists(data_path), f"{data_path} doesn't exist"
 
-    if (
-        args.distributed.dp_replicate
-        * args.distributed.dp_shard
-        * args.distributed.tp_size
-        != get_world_size()
-    ):
-        logging.info("Modifying TrainArgs distributed config")
-        assert get_world_size() % args.distributed.dp_shard == 0
-        logging.info("World size: %s", get_world_size())
-        logging.info(
-            "Existing setting: train_args.distributed.dp_shard=%s",
-            args.distributed.dp_shard,
-        )
-        logging.info(
-            "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
-            get_world_size() // args.distributed.dp_shard,
-            args.distributed.dp_replicate,
-        )
-        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
-
-        logging.info(
-            "Changing dp_replicate from %s to %s, to account for tp_size=%s",
-            args.distributed.dp_replicate,
-            args.distributed.dp_replicate // args.distributed.tp_size,
-            args.distributed.tp_size,
-        )
-        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
-        args.distributed.dp_replicate = (
-            args.distributed.dp_replicate // args.distributed.tp_size
-        )
-
-        logger.warning(
-            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
-        )
-        assert (
-            args.distributed.dp_replicate
-            * args.distributed.dp_shard
-            * args.distributed.tp_size
-            == get_world_size()
-        )
-
-    if args.distributed.fsdp_type == "no_shard":
-        assert (
-            args.distributed.dp_shard == 1
-            and args.distributed.dp_replicate == get_world_size()
-        )
+    args.distributed.configure_world()
 
     if args.model is not None:
         args.model.max_seqlen = args.data.seq_len
@@ -243,7 +193,9 @@ def set_preemption_flag(signum, frame):
     preemption_flag["flag"] = True
 
 
-def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
+def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
+    if freq < 0:
+        return False
     test = train_state.step % freq == 0
     if acc_step is not None:
         test = test and (train_state.acc_step == acc_step)
|
||||||
tokenizer = args.data.tokenizer_args.build()
|
tokenizer = args.data.tokenizer_args.build()
|
||||||
validate_train_args(
|
validate_train_args(
|
||||||
args,
|
args,
|
||||||
tokenizer.n_words,
|
tokenizer.get_vocab_size(),
|
||||||
)
|
)
|
||||||
dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
|
dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
|
||||||
if get_is_master():
|
if get_is_master():
|
||||||
|
|