Merge 9446c1ee5c into sapling-pr-archive-EntilZha
This commit is contained in: commit 2d4f277596
4 changed files with 179 additions and 31 deletions
bytelatent
@@ -2,6 +2,7 @@
 import json
 import logging
+import math
 import os
 from collections import defaultdict
 from datetime import datetime
@@ -10,11 +11,27 @@ import torch
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
+from rich.progress import track
+from torch.nn import functional as F

-from bytelatent.args import EvalArgs, ValidationArgs
+from bytelatent.args import (
+    EvalArgs,
+    TrainArgs,
+    ValidationArgs,
+    find_and_sanitize_chunks,
+)
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
+from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
+from bytelatent.data.iterators.sequence_iterator import (
+    SequenceIterator,
+    SequencePackingArgs,
+)
+from bytelatent.data.patcher import PatcherArgs
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
@@ -26,6 +43,9 @@ from bytelatent.generate import (
     PackedCausalTransformerGenerator,
     load_consolidated_model_and_tokenizer,
 )
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+from bytelatent.transformer import LMTransformer

 EVAL_FOLDER_NAME = "{:010d}"

@@ -113,19 +133,125 @@ class EvalHarnessLM(LM):
         return results


-def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
-    srcs = {}
+@torch.no_grad()
+def eval_ppl_on_path(
+    *,
+    model: LMTransformer | ByteLatentTransformer,
+    tokenizer_args: TokenizerArgs,
+    patcher_args: PatcherArgs,
+    add_patches: bool,
+    path: str,
+    batch_size: int,
+    arrow_batch_size: int,
+    max_n_docs: int | None,
+    s3_profile: str | None = None,
+):
+    model.eval()
+    tokenizer = tokenizer_args.build()
+    seq_len = model.get_output_seq_len()
+    chunks = find_and_sanitize_chunks(
+        path,
+        world_size=1,
+        file_pattern="*.val.jsonl",
+        s3_profile=s3_profile,
+    )
+    assert (
+        len(chunks) == 1
+    ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+    chunk = chunks[0]
+    arrow_iterator = ArrowFileIterator(
+        file_path=chunk,
+        preprocess_dir=None,
+        entropy_model_name=None,
+        worker_id=0,
+        num_workers=1,
+        arrow_batch_size=arrow_batch_size,
+        s3_profile=s3_profile,
+        file_format="json",
+    )
+    if max_n_docs is not None:
+        arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
+    preprocess_iterator = PreprocessIterator(
+        arrow_iterator,
+        patcher_args=patcher_args,
+        tokenizer_args=tokenizer_args,
+        add_patches=add_patches,
+    )
+    sequence_iterator = SequenceIterator(
+        preprocess_iterator,
+        sequence_packing_args=SequencePackingArgs(
+            output_seq_len=seq_len,
+            # Effectively disables shuffles
+            buffer_size=1,
+        ),
+        rng_state=None,
+    )
+    packing_args = PackingArgs(
+        batch_size=batch_size,
+        seq_len=seq_len,
+        # TODO: make these seq lens worth with blt
+        max_length=seq_len,
+        tokenizer_name=tokenizer_args.name,
+        pad_to_max_length=True,
+        enable_byte_ngrams=False,
+        pad_id=0 if tokenizer_args.name == "bytes" else tokenizer.boe_id,
+    )
+    packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
+    total_loss = 0.0
+    n_bytes = 0
+    batch_iterator = packing_iterator.create_iter()
+    for batch in batch_iterator:
+        x = torch.from_numpy(batch.x).cuda()
+        y = torch.from_numpy(batch.y).cuda()
+        mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+        if tokenizer_args.name in ["bytes", "blt"]:
+            n_bytes += y.numel() if mask is None else mask.sum().item()
+            pred = model(x)
+            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
+            total_loss += loss.item()
+        else:
+            raise NotImplementedError()
+    return {
+        "n_bytes": n_bytes,
+        "loss_sum": total_loss,
+        "ppl": math.exp(total_loss / n_bytes) if n_bytes > 0 else 0.0,
+    }
+
+
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
+    srcs = []
     for src in val_args.sources:
         path = os.path.join(val_args.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)

     for src in train_cfg.data.sources:
         path = os.path.join(train_cfg.data.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)

-    multi_state = init_choice_state(
-        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
-    )
-    path_to_iter = setup_sources(multi_state)
+    path_to_iter = {}
+    for path in srcs:
+        chunks = find_and_sanitize_chunks(
+            path,
+            world_size=1,
+            file_pattern="*.val.jsonl",
+            s3_profile=train_cfg.data.s3_profile,
+        )
+        assert (
+            len(chunks) == 1
+        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+        chunk = chunks[0]
+        iterator = ArrowFileIterator(
+            dataset_files=[chunk],
+            file_path=None,
+            preprocess_dir=None,
+            entropy_model_name=None,
+            worker_id=0,
+            num_workers=1,
+            arrow_batch_size=train_cfg.data.arrow_batch_size,
+            s3_profile=train_cfg.data.s3_profile,
+            file_format="json",
+        )
+        path_to_iter[path] = iterator

     max_gen_len = generator.max_gen_len
     # We temporarily lower max gen len
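The new eval_ppl_on_path helper feeds a validation file through the same iterator pipeline used for training data (ArrowFileIterator, optionally LimitIterator, then PreprocessIterator, SequenceIterator, and PackingIterator) and reports perplexity as exp(loss_sum / n_bytes), where loss_sum is a sum-reduced cross-entropy over target bytes. A self-contained sketch of that final reduction on dummy tensors (not part of the patch; the shapes, vocab size, and the extra bits-per-byte line are illustrative only):

import math

import torch
from torch.nn import functional as F

# Dummy batch: 2 sequences of 8 byte positions over a 260-symbol vocabulary.
# Shapes and vocab size are made up for illustration, not taken from the repo.
logits = torch.randn(2, 8, 260)
targets = torch.randint(0, 260, (2, 8))
mask = torch.ones(2, 8, dtype=torch.bool)  # every position counts as a real byte here

# Sum-reduced cross-entropy over all target bytes, as in the hunk above.
loss_sum = F.cross_entropy(
    logits.flatten(0, 1), targets.flatten(0, 1), reduction="sum"
).item()
n_bytes = int(mask.sum().item())

ppl = math.exp(loss_sum / n_bytes) if n_bytes > 0 else 0.0
bpb = loss_sum / (n_bytes * math.log(2))  # bits-per-byte, a common companion metric
print({"n_bytes": n_bytes, "loss_sum": loss_sum, "ppl": ppl, "bpb": bpb})

Because the loss is summed rather than averaged per batch, dividing by the total byte count gives an exact aggregate even when the last batch is smaller. One thing worth double-checking in the hunk itself: the cross-entropy is summed over every position, while n_bytes only counts unmasked ones, so the two can diverge if packed batches contain padding.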
@@ -133,16 +259,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):

     all_val_metrics = {}
     for src in path_to_iter:
-        jsonl_iterator = path_to_iter[src]
+        example_iterator = path_to_iter[src].create_iter()
         texts = []
         logger.info(f"Running validation on {src}...")
-        for step, (content, state) in enumerate(jsonl_iterator):
-            if state["current_iter"] > 0 or (
-                val_args.max_steps is not None and step >= val_args.max_steps
-            ):
-                break
-            content_key = "text" if ("text" in content) else "content"
-            texts.append(content[content_key])
+        for step, example in enumerate(example_iterator):
+            texts.append(example.text)

         _, loglikelihood, _ = generator.generate(texts)
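The loop now consumes examples from create_iter() and reads example.text directly. Note that the old state-based guard also enforced val_args.max_steps, which the new loop no longer checks. If that cap is still wanted, a bounded take over the iterator is one way to restore it; a runnable stand-in with dummy types (the _Example class and _dummy_iter below are placeholders, not repo code):

from dataclasses import dataclass
from itertools import islice
from typing import Iterator

@dataclass
class _Example:  # stand-in for the items yielded by create_iter()
    text: str

def _dummy_iter() -> Iterator[_Example]:  # stand-in for path_to_iter[src].create_iter()
    for i in range(10):
        yield _Example(text=f"doc {i}")

max_steps = 3  # plays the role of val_args.max_steps
texts = [ex.text for ex in islice(_dummy_iter(), max_steps)]
print(texts)  # ['doc 0', 'doc 1', 'doc 2']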
@@ -187,7 +308,7 @@ def launch_eval(eval_args: EvalArgs):
     else:
         consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
         if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)

     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -200,16 +321,39 @@ def launch_eval(eval_args: EvalArgs):
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
     )
-    logger.info("Model loaded")
     model.eval()
+    logger.info("Model loaded")
+
+    if eval_args.validation:
+        logger.info("Starting PPL evaluation on validation sets")
+        # val_results = eval_on_val(
+        val_results = eval_ppl_on_path(
+            model=model,
+            tokenizer_args=train_cfg.data.tokenizer_args,
+            # TODO: Don't hardcode, modify based on model
+            patcher_args=PatcherArgs(patching_mode="byte"),
+            add_patches=False,
+            path="/checkpoint/amaia/explore/datasets/dclm_baseline_1.0/",
+            max_n_docs=eval_args.validation.max_n_docs,
+            batch_size=8,
+            arrow_batch_size=100,
+            s3_profile="blt",
+        )
+        print(val_results)
+
+        raise NotImplementedException()

     generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)

     wrap = EvalHarnessLM(generator)
-    results = simple_evaluate(wrap, eval_args.harness.model_dump())
+    # Redo
+    # results = simple_evaluate(wrap, **eval_args.harness.model_dump())
+    results = {"results": []}

     val_results = None
     if eval_args.validation:
         val_results = eval_on_val(generator, eval_args.validation, train_cfg)

     if get_global_rank() == 0:
         with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
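This hunk is explicitly work in progress: the eval_ppl_on_path call hard-codes the dataset path, batch sizes, and s3 profile, the harness run is stubbed out to results = {"results": []}, and control flow stops at raise NotImplementedException(). That name is not a Python built-in, so reaching it raises a NameError rather than the intended exception; a tiny runnable illustration (not part of the patch):

# NotImplementedException does not exist in Python; the built-in for unfinished
# code paths is NotImplementedError.
try:
    raise NotImplementedException()  # noqa: F821 - mirrors the hunk above
except NameError as err:
    print(err)  # name 'NotImplementedException' is not defined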
@@ -218,6 +362,7 @@ def launch_eval(eval_args: EvalArgs):
         with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
             f.write(json.dumps(val_results))
         logger.info(f"All validation results: {val_results}")

     if eval_args.metric_log_dir and get_global_rank() == 0:
         metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
@@ -247,7 +392,7 @@ def launch_eval(eval_args: EvalArgs):


 def main():
-    eval_args = parse_args(EvalArgs)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
     launch_eval(eval_args)

@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
 ):
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    with fs.open(train_args_path) as f:
-        train_args = TrainArgs.model_validate_json(f.read())
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
         train_args.distributed.model_dtype
     ]
     tokenizer = train_args.data.tokenizer_args.build()
-    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
+        st_dict = torch.load(f, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
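Both hunks in load_consolidated_model_and_tokenizer replace direct pathlib access with filesystem-handle calls (fs.read_text for params.json, fs.open wrapped around torch.load for the consolidated weights), so checkpoints can live on object storage as well as local disk. A minimal sketch of the same pattern using plain fsspec; the path below is a placeholder, not a path from this repo:

import fsspec
import torch

# Placeholder location; could equally be a local path.
ckpt_path = "s3://my-bucket/checkpoints/0000001000/consolidated.pth"

# fsspec picks the right backend from the URL scheme, and torch.load accepts
# the resulting file-like object.
with fsspec.open(ckpt_path, "rb") as f:
    state = torch.load(f, weights_only=True, map_location="cpu")
model_state = state["model"]  # the hunk above loads this dict into the model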
@@ -55,7 +55,7 @@ class LoggingArgs(BaseModel):
 class MetricLogger:
     def __init__(
         self,
-        outdir: Path,
+        outdir: str,
         # args: TrainArgs
         args: Any | None = None,
         fs: fsspec.AbstractFileSystem | None = None,
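One plausible motivation for switching outdir from Path to str (the commit does not say): once output locations may be fsspec URLs such as s3://..., pathlib mangles them, while plain string joins keep the scheme intact. A quick runnable check:

import os
from pathlib import PurePosixPath

print(str(PurePosixPath("s3://bucket/metrics")))  # s3:/bucket/metrics  (double slash collapsed)
print(os.path.join("s3://bucket", "metrics"))     # s3://bucket/metrics (scheme preserved)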
@@ -152,10 +152,11 @@ def validate_train_args(args: TrainArgs, output_size: int):
         logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
         args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")

-    data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
-    for source in args.data.sources:
-        data_path = os.path.join(args.data.root_dir, source)
-        assert data_fs.exists(data_path), f"{data_path} doesn't exist"
+    if args.data.root_dir is not None:
+        data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
+        for source in args.data.sources:
+            data_path = os.path.join(args.data.root_dir, source)
+            assert data_fs.exists(data_path), f"{data_path} doesn't exist"

     if (
         args.distributed.dp_replicate
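The added guard lets configs leave args.data.root_dir unset without tripping the source existence assertions. A hedged sketch of the same check using fsspec directly instead of the repo's get_fs helper; root_dir and sources below are placeholders:

import fsspec

root_dir = None  # e.g. "s3://my-bucket/datasets" or "/data/shuffled"
sources = ["dclm_baseline_1.0"]

if root_dir is not None:
    # get_fs_token_paths returns (filesystem, token, expanded paths)
    fs, _, _ = fsspec.get_fs_token_paths(root_dir)
    for source in sources:
        data_path = f"{root_dir.rstrip('/')}/{source}"
        assert fs.exists(data_path), f"{data_path} doesn't exist"
else:
    print("no data root configured; skipping source existence checks")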
@@ -241,7 +242,9 @@ def set_preemption_flag(signum, frame):
     preemption_flag["flag"] = True


-def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
+def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
+    if freq < 0:
+        return False
     test = train_state.step % freq == 0
     if acc_step is not None:
         test = test and (train_state.acc_step == acc_step)
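With the added check, a negative freq now short-circuits to False, so periodic actions can be disabled from the config (previously freq=-1 made step % freq evaluate to 0 on every step, firing the action constantly). A simplified, runnable stand-in, not the full helper; the _State dataclass is a dummy for the real train_state:

from dataclasses import dataclass

@dataclass
class _State:  # minimal stand-in for the real train_state
    step: int
    acc_step: int = 0

def _every_n_steps(train_state, freq: int, acc_step=None) -> bool:
    # Mirrors the updated guard: a negative freq means "never fire".
    if freq < 0:
        return False
    test = train_state.step % freq == 0
    if acc_step is not None:
        test = test and (train_state.acc_step == acc_step)
    return test

print(_every_n_steps(_State(step=100), freq=50))  # True
print(_every_n_steps(_State(step=100), freq=-1))  # False: disabled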
@@ -268,7 +271,7 @@ def train(args: TrainArgs):
     tokenizer = args.data.tokenizer_args.build()
     validate_train_args(
         args,
-        tokenizer.n_words,
+        tokenizer.get_vocab_size(),
     )
     dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
     if get_is_master():
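validate_train_args is now given tokenizer.get_vocab_size() instead of the tokenizer.n_words attribute, which reads as a move toward a uniform accessor across tokenizer types whose vocabulary sizes are derived differently (this is an inference, not stated in the commit). An illustrative sketch with made-up tokenizer classes, not the repo's real ones:

# Hypothetical tokenizers showing why a get_vocab_size() method is friendlier
# than relying on a fixed n_words attribute.
class ByteTokenizer:
    BYTE_UNITS = 256
    N_SPECIAL = 4  # e.g. bos/eos/pad/boe; the exact count is illustrative

    def get_vocab_size(self) -> int:
        return self.BYTE_UNITS + self.N_SPECIAL

class SentencePieceTokenizer:
    def __init__(self, n_words: int):
        self.n_words = n_words

    def get_vocab_size(self) -> int:
        return self.n_words

for tok in (ByteTokenizer(), SentencePieceTokenizer(n_words=32000)):
    print(type(tok).__name__, tok.get_vocab_size())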