Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-23 13:32:14 +00:00)

Commit f912535cb7: Merge 655eca670d into sapling-pr-archive-EntilZha
@@ -248,8 +248,8 @@ class ValidationArgs(BaseModel):

 class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    dump_dir: str
-    ckpt_dir: str
+    dump_dir: str | None = None
+    ckpt_dir: str | None = None
     metric_log_dir: str | None = None
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
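For context: with pydantic v2, changing these fields to `str | None = None` lets an eval config omit `dump_dir` and `ckpt_dir` entirely, while `extra="forbid"` still rejects unknown keys. A minimal standalone sketch of the resulting behavior (not the repo's module; the paths are hypothetical):

    from pydantic import BaseModel, ConfigDict

    class EvalArgs(BaseModel):
        model_config = ConfigDict(extra="forbid")
        dump_dir: str | None = None
        ckpt_dir: str | None = None

    EvalArgs()                             # now valid: both fields default to None
    EvalArgs(ckpt_dir="/ckpts/step1000")   # explicit paths still accepted
    # EvalArgs(typo_dir="x")               # still raises: extra fields forbidden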
@@ -11,10 +11,16 @@ from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM

-from bytelatent.args import EvalArgs, ValidationArgs
+from bytelatent.args import (
+    EvalArgs,
+    TrainArgs,
+    ValidationArgs,
+    find_and_sanitize_chunks,
+)
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
@@ -113,19 +119,40 @@ class EvalHarnessLM(LM):
         return results


-def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
-    srcs = {}
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
+    srcs = []
     for src in val_args.sources:
         path = os.path.join(val_args.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)

     for src in train_cfg.data.sources:
         path = os.path.join(train_cfg.data.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)

-    multi_state = init_choice_state(
-        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
-    )
-    path_to_iter = setup_sources(multi_state)
+    path_to_iter = {}
+    for path in srcs:
+        chunks = find_and_sanitize_chunks(
+            path,
+            world_size=1,
+            file_pattern="*.val.jsonl",
+            s3_profile=train_cfg.data.s3_profile,
+        )
+        assert (
+            len(chunks) == 1
+        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+        chunk = chunks[0]
+        iterator = ArrowFileIterator(
+            dataset_files=[chunk],
+            file_path=None,
+            preprocess_dir=None,
+            entropy_model_name=None,
+            worker_id=0,
+            num_workers=1,
+            arrow_batch_size=train_cfg.data.arrow_batch_size,
+            s3_profile=train_cfg.data.s3_profile,
+            file_format="json",
+        )
+        path_to_iter[path] = iterator

     max_gen_len = generator.max_gen_len
     # We temporarily lower max gen len
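Reviewer note: this rewrite drops the weighted-sampling setup (`init_choice_state` / `setup_sources`, with every source weighted 1.0) in favor of one deterministic `ArrowFileIterator` per source. The `assert len(chunks) == 1` encodes a layout contract: each validation source must resolve to exactly one `*.val.jsonl` chunk. A sketch of that contract, mirroring the call in the diff ("wiki" is a hypothetical source name):

    chunks = find_and_sanitize_chunks(
        os.path.join(val_args.root_dir, "wiki"),
        world_size=1,
        file_pattern="*.val.jsonl",
        s3_profile=train_cfg.data.s3_profile,
    )
    assert len(chunks) == 1  # exactly one validation chunk per source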
@@ -133,16 +160,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):

     all_val_metrics = {}
     for src in path_to_iter:
-        jsonl_iterator = path_to_iter[src]
+        example_iterator = path_to_iter[src].create_iter()
         texts = []
         logger.info(f"Running validation on {src}...")
-        for step, (content, state) in enumerate(jsonl_iterator):
-            if state["current_iter"] > 0 or (
-                val_args.max_steps is not None and step >= val_args.max_steps
-            ):
-                break
-            content_key = "text" if ("text" in content) else "content"
-            texts.append(content[content_key])
+        for step, example in enumerate(example_iterator):
+            texts.append(example.text)

         _, loglikelihood, _ = generator.generate(texts)
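Reviewer note: the new loop no longer honors `val_args.max_steps`. The old code broke out once `step >= val_args.max_steps`; the rewrite consumes every example in the file. If the cap is still wanted, a guard along these lines (a sketch, not part of the diff) would restore it:

    for step, example in enumerate(example_iterator):
        if val_args.max_steps is not None and step >= val_args.max_steps:
            break
        texts.append(example.text)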
@@ -187,7 +209,7 @@ def launch_eval(eval_args: EvalArgs):
     else:
         consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
         if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)

     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -206,10 +228,13 @@ def launch_eval(eval_args: EvalArgs):

     wrap = EvalHarnessLM(generator)
     # Redo
-    results = simple_evaluate(wrap, eval_args.harness.model_dump())
+    # results = simple_evaluate(wrap, **eval_args.harness.model_dump())
+    results = {"results": []}

     val_results = None
     if eval_args.validation:
         val_results = eval_on_val(generator, eval_args.validation, train_cfg)

     if get_global_rank() == 0:
         with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
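Reviewer note: the harness call is disabled here and `results` is stubbed to `{"results": []}`, matching the existing `# Redo` marker. The commented replacement also switches from passing the dump positionally to keyword-splatting it, which fits `lm_eval.simple_evaluate` (assuming the harness fields mirror its keyword parameters such as `tasks` and `num_fewshot`; positionally, the dict would bind to the wrong parameter). When re-enabled it would read:

    results = simple_evaluate(wrap, **eval_args.harness.model_dump())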
@@ -218,6 +243,7 @@ def launch_eval(eval_args: EvalArgs):
         with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
             f.write(json.dumps(val_results))
         logger.info(f"All validation results: {val_results}")

     if eval_args.metric_log_dir and get_global_rank() == 0:
         metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
@@ -247,7 +273,7 @@ def launch_eval(eval_args: EvalArgs):


 def main():
-    eval_args = parse_args(EvalArgs)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
     launch_eval(eval_args)

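The entry point moves from the old `parse_args` helper to `parse_args_to_pydantic_model`, consistent with the new `bytelatent.config_parser` import above. A hypothetical invocation, assuming the parser layers `key=value` command-line overrides on top of a YAML config (an assumption, not confirmed by this diff):

    # python -m bytelatent.eval config=eval.yaml ckpt_dir=/ckpts/step1000   (hypothetical CLI)
    eval_args = parse_args_to_pydantic_model(EvalArgs)
    launch_eval(eval_args)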
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
 ):
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    with fs.open(train_args_path) as f:
-        train_args = TrainArgs.model_validate_json(f.read())
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
         train_args.distributed.model_dtype
     ]
     tokenizer = train_args.data.tokenizer_args.build()
-    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
+        st_dict = torch.load(f, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
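Both changes in this function route I/O through the filesystem returned by `get_fs`, so `params.json` and the consolidated weights load the same way from local disk or remote storage. A self-contained sketch of the pattern (paths and the checkpoint filename are hypothetical; assumes `get_fs` returns an fsspec-style filesystem):

    import os
    import torch
    from bytelatent.data.file_util import get_fs

    ckpt_dir = "/ckpts/consolidated"   # hypothetical checkpoint directory
    fs = get_fs(ckpt_dir)
    with fs.open(os.path.join(ckpt_dir, "consolidated.pth")) as f:
        st_dict = torch.load(f, weights_only=True)  # torch.load accepts file-like objects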
@@ -241,7 +241,9 @@ def set_preemption_flag(signum, frame):
     preemption_flag["flag"] = True


-def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
+def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
+    if freq < 0:
+        return False
     test = train_state.step % freq == 0
     if acc_step is not None:
         test = test and (train_state.acc_step == acc_step)
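The added guard makes a negative `freq` mean "never": the function returns `False` before computing `train_state.step % freq`, so callers can pass `-1` to disable a periodic action instead of special-casing it. A small sketch, assuming the function ultimately returns `test` (the stand-in `TrainState` is illustrative):

    class TrainState:          # minimal stand-in for the real train state
        step = 100
        acc_step = 0

    every_n_steps(TrainState, freq=10)   # True: step 100 is a multiple of 10
    every_n_steps(TrainState, freq=-1)   # False: negative freq now means "disabled"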