diff --git a/bytelatent/args.py b/bytelatent/args.py
index dd1fef5..15e44f1 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -248,8 +248,8 @@ class ValidationArgs(BaseModel):
 class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
 
-    dump_dir: str
-    ckpt_dir: str
+    dump_dir: str | None = None
+    ckpt_dir: str | None = None
     metric_log_dir: str | None = None
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
diff --git a/bytelatent/eval.py b/bytelatent/eval.py
index 50e17cd..3ffe0ae 100644
--- a/bytelatent/eval.py
+++ b/bytelatent/eval.py
@@ -11,10 +11,16 @@ from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 
-from bytelatent.args import EvalArgs, ValidationArgs
+from bytelatent.args import (
+    EvalArgs,
+    TrainArgs,
+    ValidationArgs,
+    find_and_sanitize_chunks,
+)
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
@@ -113,19 +119,40 @@ class EvalHarnessLM(LM):
         return results
 
 
-def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
-    srcs = {}
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
+    srcs = []
     for src in val_args.sources:
         path = os.path.join(val_args.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)
+
     for src in train_cfg.data.sources:
         path = os.path.join(train_cfg.data.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)
 
-    multi_state = init_choice_state(
-        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
-    )
-    path_to_iter = setup_sources(multi_state)
+    path_to_iter = {}
+    for path in srcs:
+        chunks = find_and_sanitize_chunks(
+            path,
+            world_size=1,
+            file_pattern="*.val.jsonl",
+            s3_profile=train_cfg.data.s3_profile,
+        )
+        assert (
+            len(chunks) == 1
+        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+        chunk = chunks[0]
+        iterator = ArrowFileIterator(
+            dataset_files=[chunk],
+            file_path=None,
+            preprocess_dir=None,
+            entropy_model_name=None,
+            worker_id=0,
+            num_workers=1,
+            arrow_batch_size=train_cfg.data.arrow_batch_size,
+            s3_profile=train_cfg.data.s3_profile,
+            file_format="json",
+        )
+        path_to_iter[path] = iterator
 
     max_gen_len = generator.max_gen_len
     # We temporarily lower max gen len
@@ -133,16 +160,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
     all_val_metrics = {}
     for src in path_to_iter:
-        jsonl_iterator = path_to_iter[src]
+        example_iterator = path_to_iter[src].create_iter()
         texts = []
         logger.info(f"Running validation on {src}...")
-        for step, (content, state) in enumerate(jsonl_iterator):
-            if state["current_iter"] > 0 or (
-                val_args.max_steps is not None and step >= val_args.max_steps
-            ):
-                break
-            content_key = "text" if ("text" in content) else "content"
-            texts.append(content[content_key])
+        for step, example in enumerate(example_iterator):
+            texts.append(example.text)
 
         _, loglikelihood, _ = generator.generate(texts)
 
@@ -187,7 +209,7 @@ def launch_eval(eval_args: EvalArgs):
     else:
         consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
         if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
 
     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -206,10 +228,13 @@ def launch_eval(eval_args: EvalArgs):
     wrap = EvalHarnessLM(generator)
     # Redo
-    results = simple_evaluate(wrap, eval_args.harness.model_dump())
+    # results = simple_evaluate(wrap, **eval_args.harness.model_dump())
+    results = {"results": []}
+    val_results = None
     if eval_args.validation:
         val_results = eval_on_val(generator, eval_args.validation, train_cfg)
+
     if get_global_rank() == 0:
         with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
@@ -218,6 +243,7 @@ def launch_eval(eval_args: EvalArgs):
         with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
             f.write(json.dumps(val_results))
         logger.info(f"All validation results: {val_results}")
+
     if eval_args.metric_log_dir and get_global_rank() == 0:
         metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
@@ -247,7 +273,7 @@ def launch_eval(eval_args: EvalArgs):
 
 def main():
-    eval_args = parse_args(EvalArgs)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
     launch_eval(eval_args)
 
diff --git a/bytelatent/generate.py b/bytelatent/generate.py
index eb79d81..9d44f30 100644
--- a/bytelatent/generate.py
+++ b/bytelatent/generate.py
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
 ):
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    with fs.open(train_args_path) as f:
-        train_args = TrainArgs.model_validate_json(f.read())
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
 
     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
         train_args.distributed.model_dtype
     ]
     tokenizer = train_args.data.tokenizer_args.build()
-    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
+        st_dict = torch.load(f, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
diff --git a/bytelatent/train.py b/bytelatent/train.py
index ad74b44..af1c694 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -241,7 +241,9 @@ def set_preemption_flag(signum, frame):
     preemption_flag["flag"] = True
 
 
-def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
+def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
+    if freq < 0:
+        return False
     test = train_state.step % freq == 0
     if acc_step is not None:
         test = test and (train_state.acc_step == acc_step)