From ece82cb96064e6bf55c09f9ac28807e4e82564b1 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 12 Feb 2025 19:24:49 +0000 Subject: [PATCH] Make it possible to specify multiple config files Summary: Test Plan: Test that this iterpolates in the right order, config -> configs -> cli args ``` # All three sources python -m bytelatent.print_config config=bytelatent/configs/debug.yaml configs=[internal/configs/s3_debug.yaml] eval=null # What worked before python -m bytelatent.print_config config=internal/configs/s3_debug.yaml eval=null ``` --- bytelatent/args.py | 19 +++++++++++++++---- bytelatent/configs/debug.yaml | 4 +--- bytelatent/print_config.py | 10 ++++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) create mode 100644 bytelatent/print_config.py diff --git a/bytelatent/args.py b/bytelatent/args.py index 47bd0f9..6125ae7 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -40,12 +40,23 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: def parse_args(args_cls): cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config + file_cfgs = [] + if "config" in cli_args: + file_cfg = OmegaConf.load(cli_args["config"]) + del cli_args["config"] + file_cfgs.append(file_cfg) + + if "configs" in cli_args: + for c in cli_args["configs"]: + extra_file_cfg = OmegaConf.load(c) + file_cfgs.append(extra_file_cfg) + del cli_args["configs"] default_cfg = OmegaConf.create(args_cls().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + to_merge = [default_cfg] + to_merge.extend(file_cfgs) + to_merge.append(cli_args) + cfg = OmegaConf.merge(*to_merge) cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) pydantic_args = args_cls.model_validate(cfg) return pydantic_args diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 07d489f..1369364 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -56,13 +56,11 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - patch_only_encoder: false - patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" - attn_bias_type: "block_causal" attn_impl: "xformers" + attn_bias_type: "block_causal" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/print_config.py b/bytelatent/print_config.py new file mode 100644 index 0000000..3fd509a --- /dev/null +++ b/bytelatent/print_config.py @@ -0,0 +1,10 @@ +from bytelatent.args import TrainArgs, parse_args + + +def main(): + train_args = parse_args(TrainArgs) + print(train_args.model_dump_json(indent=4)) + + +if __name__ == "__main__": + main()